From 405b3bbfdb3e3d963fcc344d9a9e2d333db7d492 Mon Sep 17 00:00:00 2001 From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com> Date: Wed, 23 Nov 2022 17:39:03 -0800 Subject: [PATCH] Fix default auth and 'AzureTenantId' to persist and not reset on selection event (#21289) --- .../providers/database/databaseTreeDataProvider.ts | 5 +++-- .../databaseServer/databaseServerTreeDataProvider.ts | 5 ++--- .../providers/resourceTreeDataProviderBase.ts | 4 ++-- extensions/azurecore/src/azureResource/resourceService.ts | 4 ++-- .../src/azureResource/tree/flatAccountTreeNode.ts | 4 ++-- .../azurecore/src/azureResource/tree/flatTreeProvider.ts | 2 +- .../src/azureResource/tree/subscriptionTreeNode.ts | 2 +- .../providers/database/databaseTreeDataProvider.test.ts | 3 +++ .../databaseServer/databaseServerTreeDataProvider.test.ts | 3 +++ .../src/test/azureResource/resourceService.test.ts | 8 ++++---- .../services/connection/browser/connectionWidget.ts | 4 +--- 11 files changed, 24 insertions(+), 20 deletions(-) diff --git a/extensions/azurecore/src/azureResource/providers/database/databaseTreeDataProvider.ts b/extensions/azurecore/src/azureResource/providers/database/databaseTreeDataProvider.ts index 48793cfaf1..d8b5bf04a8 100644 --- a/extensions/azurecore/src/azureResource/providers/database/databaseTreeDataProvider.ts +++ b/extensions/azurecore/src/azureResource/providers/database/databaseTreeDataProvider.ts @@ -3,7 +3,7 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import { connection, TreeItem, ExtensionNodeType } from 'azdata'; +import { TreeItem, ExtensionNodeType } from 'azdata'; import * as vscode from 'vscode'; import * as nls from 'vscode-nls'; const localize = nls.loadMessageBundle(); @@ -25,6 +25,7 @@ export class AzureResourceDatabaseTreeDataProvider extends ResourceTreeDataProvi ) { super(databaseService); } + protected getTreeItemForResource(database: azureResource.AzureResourceDatabase, account: AzureAccount): TreeItem { return { id: `databaseServer_${database.serverFullName}.database_${database.name}`, @@ -42,7 +43,7 @@ export class AzureResourceDatabaseTreeDataProvider extends ResourceTreeDataProvi databaseName: database.name, userName: database.loginName, password: '', - authenticationType: connection.AuthenticationType.AzureMFA, + authenticationType: '', savePassword: true, groupFullName: '', groupId: '', diff --git a/extensions/azurecore/src/azureResource/providers/databaseServer/databaseServerTreeDataProvider.ts b/extensions/azurecore/src/azureResource/providers/databaseServer/databaseServerTreeDataProvider.ts index 3464d1d592..db7f49105a 100644 --- a/extensions/azurecore/src/azureResource/providers/databaseServer/databaseServerTreeDataProvider.ts +++ b/extensions/azurecore/src/azureResource/providers/databaseServer/databaseServerTreeDataProvider.ts @@ -3,7 +3,7 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import { connection, ExtensionNodeType, TreeItem } from 'azdata'; +import { ExtensionNodeType, TreeItem } from 'azdata'; import * as vscode from 'vscode'; import * as nls from 'vscode-nls'; const localize = nls.loadMessageBundle(); @@ -25,7 +25,6 @@ export class AzureResourceDatabaseServerTreeDataProvider extends ResourceTreeDat super(databaseServerService); } - protected getTreeItemForResource(databaseServer: azureResource.AzureResourceDatabaseServer, account: AzureAccount): TreeItem { return { id: `databaseServer_${databaseServer.id ? databaseServer.id : databaseServer.name}`, @@ -43,7 +42,7 @@ export class AzureResourceDatabaseServerTreeDataProvider extends ResourceTreeDat databaseName: databaseServer.defaultDatabaseName, userName: databaseServer.loginName, password: '', - authenticationType: connection.AuthenticationType.AzureMFA, + authenticationType: '', savePassword: true, groupFullName: '', groupId: '', diff --git a/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts b/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts index 5dfb369bac..ae1e0d611d 100644 --- a/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts +++ b/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts @@ -28,7 +28,7 @@ export abstract class ResourceTreeDataProviderBase { account: element.account, subscription: element.subscription, - tenantId: element.tenantId, + tenantId: element.subscription.tenant, treeItem: this.getTreeItemForResource(resource, element.account) }).sort((a, b) => (a.treeItem.label).localeCompare(b.treeItem.label)); } catch (error) { @@ -38,7 +38,7 @@ export abstract class ResourceTreeDataProviderBase { - const response = await azdata.accounts.getAccountSecurityToken(element.account, element.tenantId, azdata.AzureResource.ResourceManagement); + const response = await azdata.accounts.getAccountSecurityToken(element.account, element.subscription.tenant!, azdata.AzureResource.ResourceManagement); if (!response) { throw new Error(`Did not receive security token when getting resources for account ${element.account.displayInfo.displayName}`); } diff --git a/extensions/azurecore/src/azureResource/resourceService.ts b/extensions/azurecore/src/azureResource/resourceService.ts index 3bab08cedb..20ea653075 100644 --- a/extensions/azurecore/src/azureResource/resourceService.ts +++ b/extensions/azurecore/src/azureResource/resourceService.ts @@ -32,7 +32,7 @@ export class AzureResourceService { this._areResourceProvidersLoaded = false; } - public async getRootChildren(resourceProviderId: string, account: AzureAccount, subscription: azureResource.AzureResourceSubscription, tenantId: string): Promise { + public async getRootChildren(resourceProviderId: string, account: AzureAccount, subscription: azureResource.AzureResourceSubscription): Promise { await this.ensureResourceProvidersRegistered(); if (!(resourceProviderId in this._resourceProviders)) { @@ -48,7 +48,7 @@ export class AzureResourceService { resourceNode: { account, subscription, - tenantId, + tenantId: subscription.tenant!, treeItem: rootChild } }; diff --git a/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts b/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts index aa0c24fbda..5bd75d955b 100644 --- a/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts +++ b/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts @@ -206,10 +206,10 @@ class FlatAccountTreeNodeLoader { const resourceProviderIds = await this._resourceService.listResourceProviderIds(); for (const subscription of subscriptions) { for (const providerId of resourceProviderIds) { - const resourceTypes = await this._resourceService.getRootChildren(providerId, this._account, subscription, subscription.tenant!); + const resourceTypes = await this._resourceService.getRootChildren(providerId, this._account, subscription); for (const resourceType of resourceTypes) { const resources = await this._resourceService.getChildren(providerId, resourceType.resourceNode, true); - if (resources.length > 0) { + if (resources?.length > 0) { this._nodes.push(...resources.map(dr => new AzureResourceResourceTreeNode(dr, this._accountNode, this.appContext))); this._nodes = this.nodes.sort((a, b) => { return a.getNodeInfo().label.localeCompare(b.getNodeInfo().label); diff --git a/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts b/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts index e7a26e5380..37f9fd83ca 100644 --- a/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts +++ b/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts @@ -126,7 +126,7 @@ class ResourceLoader { for (const tenant of account.properties.tenants) { for (const subscription of await this.subscriptionService.getSubscriptions(account, [tenant.id])) { for (const providerId of await this.resourceService.listResourceProviderIds()) { - for (const group of await this.resourceService.getRootChildren(providerId, account, subscription, subscription.tenant!)) { + for (const group of await this.resourceService.getRootChildren(providerId, account, subscription)) { const children = await this.resourceService.getChildren(providerId, group.resourceNode); let groupNode: AzureResourceResourceTreeNode | undefined = this.resourceGroups.get(group.resourceProviderId); if (groupNode) { diff --git a/extensions/azurecore/src/azureResource/tree/subscriptionTreeNode.ts b/extensions/azurecore/src/azureResource/tree/subscriptionTreeNode.ts index 06e49e422f..fbbc97cf44 100644 --- a/extensions/azurecore/src/azureResource/tree/subscriptionTreeNode.ts +++ b/extensions/azurecore/src/azureResource/tree/subscriptionTreeNode.ts @@ -42,7 +42,7 @@ export class AzureResourceSubscriptionTreeNode extends AzureResourceContainerTre const children: IAzureResourceNodeWithProviderId[] = []; for (const resourceProviderId of await resourceService.listResourceProviderIds()) { - children.push(...await resourceService.getRootChildren(resourceProviderId, this.account, this.subscription, this.tenantId)); + children.push(...await resourceService.getRootChildren(resourceProviderId, this.account, this.subscription)); } if (children.length === 0) { diff --git a/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts b/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts index 28cb0ffb9b..692ec7bf62 100644 --- a/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts +++ b/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts @@ -164,6 +164,9 @@ describe('AzureResourceDatabaseTreeDataProvider.getChildren', function (): void should(child.treeItem.label).equal(`${database.name} (${database.serverName})`); should(child.treeItem.collapsibleState).equal(vscode.TreeItemCollapsibleState.Collapsed); should(child.treeItem.contextValue).equal(AzureResourceItemType.database); + + // Authentication type should be empty string by default to support setting 'Sql: Default Authentication Type'. + should(child.treeItem.payload!.authenticationType).equal(''); } }); }); diff --git a/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts b/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts index ce812eb1c7..0d0f826a90 100644 --- a/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts +++ b/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts @@ -163,6 +163,9 @@ describe('AzureResourceDatabaseServerTreeDataProvider.getChildren', function (): should(child.treeItem.label).equal(databaseServer.name); should(child.treeItem.collapsibleState).equal(vscode.TreeItemCollapsibleState.Collapsed); should(child.treeItem.contextValue).equal(AzureResourceItemType.databaseServer); + + // Authentication type should be empty string by default to support setting 'Sql: Default Authentication Type'. + should(child.treeItem.payload!.authenticationType).equal(''); } }); }); diff --git a/extensions/azurecore/src/test/azureResource/resourceService.test.ts b/extensions/azurecore/src/test/azureResource/resourceService.test.ts index 4e1be378a5..ca286aed07 100644 --- a/extensions/azurecore/src/test/azureResource/resourceService.test.ts +++ b/extensions/azurecore/src/test/azureResource/resourceService.test.ts @@ -106,7 +106,7 @@ describe('AzureResourceService.getRootChildren', function (): void { }); it('Should be correct when provider id is correct.', async function (): Promise { - const children = await resourceService.getRootChildren(mockResourceProvider1.object.providerId, mockAccount, mockSubscription, mockTenantId); + const children = await resourceService.getRootChildren(mockResourceProvider1.object.providerId, mockAccount, mockSubscription); should(children).Array(); }); @@ -114,7 +114,7 @@ describe('AzureResourceService.getRootChildren', function (): void { it('Should throw exceptions when provider id is incorrect.', async function (): Promise { const providerId = 'non_existent_provider_id'; try { - await resourceService.getRootChildren(providerId, mockAccount, mockSubscription, mockTenantId); + await resourceService.getRootChildren(providerId, mockAccount, mockSubscription); } catch (error) { should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`); return; @@ -147,7 +147,7 @@ describe('AzureResourceService.getChildren', function (): void { it('Should throw exceptions when provider id is incorrect.', async function (): Promise { const providerId = 'non_existent_provider_id'; try { - await resourceService.getRootChildren(providerId, mockAccount, mockSubscription, mockTenantId); + await resourceService.getRootChildren(providerId, mockAccount, mockSubscription); } catch (error) { should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`); return; @@ -180,7 +180,7 @@ describe('AzureResourceService.getTreeItem', function (): void { it('Should throw exceptions when provider id is incorrect.', async function (): Promise { const providerId = 'non_existent_provider_id'; try { - await resourceService.getRootChildren(providerId, mockAccount, mockSubscription, mockTenantId); + await resourceService.getRootChildren(providerId, mockAccount, mockSubscription); } catch (error) { should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`); return; diff --git a/src/sql/workbench/services/connection/browser/connectionWidget.ts b/src/sql/workbench/services/connection/browser/connectionWidget.ts index f7d386cf81..94b4e55cd1 100644 --- a/src/sql/workbench/services/connection/browser/connectionWidget.ts +++ b/src/sql/workbench/services/connection/browser/connectionWidget.ts @@ -686,17 +686,15 @@ export class ConnectionWidget extends lifecycle.Disposable { } private onAzureTenantSelected(tenantIndex: number): void { - this._azureTenantId = undefined; let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); if (account && account.properties.tenants) { let tenant = account.properties.tenants[tenantIndex]; if (tenant) { - this._azureTenantId = tenant.id; this._callbacks.onAzureTenantSelection(tenant.id); } else { // This should ideally never ever happen! - this._logService.error(`onAzureTenantSelected : Could not find tenant with ID ${this._azureTenantId} for account ${account.displayInfo.displayName}`); + this._logService.error(`onAzureTenantSelected : Tenant list not found as expected, missing tenant on index ${tenantIndex}`); } } }