Fix default auth and 'AzureTenantId' to persist and not reset on selection event (#21289)

This commit is contained in:
Cheena Malhotra
2022-11-23 17:39:03 -08:00
committed by GitHub
parent 86c3f315f2
commit 405b3bbfdb
11 changed files with 24 additions and 20 deletions

View File

@@ -3,7 +3,7 @@
* Licensed under the Source EULA. See License.txt in the project root for license information. * Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/ *--------------------------------------------------------------------------------------------*/
import { connection, TreeItem, ExtensionNodeType } from 'azdata'; import { TreeItem, ExtensionNodeType } from 'azdata';
import * as vscode from 'vscode'; import * as vscode from 'vscode';
import * as nls from 'vscode-nls'; import * as nls from 'vscode-nls';
const localize = nls.loadMessageBundle(); const localize = nls.loadMessageBundle();
@@ -25,6 +25,7 @@ export class AzureResourceDatabaseTreeDataProvider extends ResourceTreeDataProvi
) { ) {
super(databaseService); super(databaseService);
} }
protected getTreeItemForResource(database: azureResource.AzureResourceDatabase, account: AzureAccount): TreeItem { protected getTreeItemForResource(database: azureResource.AzureResourceDatabase, account: AzureAccount): TreeItem {
return { return {
id: `databaseServer_${database.serverFullName}.database_${database.name}`, id: `databaseServer_${database.serverFullName}.database_${database.name}`,
@@ -42,7 +43,7 @@ export class AzureResourceDatabaseTreeDataProvider extends ResourceTreeDataProvi
databaseName: database.name, databaseName: database.name,
userName: database.loginName, userName: database.loginName,
password: '', password: '',
authenticationType: connection.AuthenticationType.AzureMFA, authenticationType: '',
savePassword: true, savePassword: true,
groupFullName: '', groupFullName: '',
groupId: '', groupId: '',

View File

@@ -3,7 +3,7 @@
* Licensed under the Source EULA. See License.txt in the project root for license information. * Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/ *--------------------------------------------------------------------------------------------*/
import { connection, ExtensionNodeType, TreeItem } from 'azdata'; import { ExtensionNodeType, TreeItem } from 'azdata';
import * as vscode from 'vscode'; import * as vscode from 'vscode';
import * as nls from 'vscode-nls'; import * as nls from 'vscode-nls';
const localize = nls.loadMessageBundle(); const localize = nls.loadMessageBundle();
@@ -25,7 +25,6 @@ export class AzureResourceDatabaseServerTreeDataProvider extends ResourceTreeDat
super(databaseServerService); super(databaseServerService);
} }
protected getTreeItemForResource(databaseServer: azureResource.AzureResourceDatabaseServer, account: AzureAccount): TreeItem { protected getTreeItemForResource(databaseServer: azureResource.AzureResourceDatabaseServer, account: AzureAccount): TreeItem {
return { return {
id: `databaseServer_${databaseServer.id ? databaseServer.id : databaseServer.name}`, id: `databaseServer_${databaseServer.id ? databaseServer.id : databaseServer.name}`,
@@ -43,7 +42,7 @@ export class AzureResourceDatabaseServerTreeDataProvider extends ResourceTreeDat
databaseName: databaseServer.defaultDatabaseName, databaseName: databaseServer.defaultDatabaseName,
userName: databaseServer.loginName, userName: databaseServer.loginName,
password: '', password: '',
authenticationType: connection.AuthenticationType.AzureMFA, authenticationType: '',
savePassword: true, savePassword: true,
groupFullName: '', groupFullName: '',
groupId: '', groupId: '',

View File

@@ -28,7 +28,7 @@ export abstract class ResourceTreeDataProviderBase<T extends azureResource.Azure
return resources.map((resource) => <azureResource.IAzureResourceNode>{ return resources.map((resource) => <azureResource.IAzureResourceNode>{
account: element.account, account: element.account,
subscription: element.subscription, subscription: element.subscription,
tenantId: element.tenantId, tenantId: element.subscription.tenant,
treeItem: this.getTreeItemForResource(resource, element.account) treeItem: this.getTreeItemForResource(resource, element.account)
}).sort((a, b) => (<any>a.treeItem.label).localeCompare(b.treeItem.label)); }).sort((a, b) => (<any>a.treeItem.label).localeCompare(b.treeItem.label));
} catch (error) { } catch (error) {
@@ -38,7 +38,7 @@ export abstract class ResourceTreeDataProviderBase<T extends azureResource.Azure
} }
private async getResources(element: azureResource.IAzureResourceNode): Promise<T[]> { private async getResources(element: azureResource.IAzureResourceNode): Promise<T[]> {
const response = await azdata.accounts.getAccountSecurityToken(element.account, element.tenantId, azdata.AzureResource.ResourceManagement); const response = await azdata.accounts.getAccountSecurityToken(element.account, element.subscription.tenant!, azdata.AzureResource.ResourceManagement);
if (!response) { if (!response) {
throw new Error(`Did not receive security token when getting resources for account ${element.account.displayInfo.displayName}`); throw new Error(`Did not receive security token when getting resources for account ${element.account.displayInfo.displayName}`);
} }

View File

@@ -32,7 +32,7 @@ export class AzureResourceService {
this._areResourceProvidersLoaded = false; this._areResourceProvidersLoaded = false;
} }
public async getRootChildren(resourceProviderId: string, account: AzureAccount, subscription: azureResource.AzureResourceSubscription, tenantId: string): Promise<IAzureResourceNodeWithProviderId[]> { public async getRootChildren(resourceProviderId: string, account: AzureAccount, subscription: azureResource.AzureResourceSubscription): Promise<IAzureResourceNodeWithProviderId[]> {
await this.ensureResourceProvidersRegistered(); await this.ensureResourceProvidersRegistered();
if (!(resourceProviderId in this._resourceProviders)) { if (!(resourceProviderId in this._resourceProviders)) {
@@ -48,7 +48,7 @@ export class AzureResourceService {
resourceNode: { resourceNode: {
account, account,
subscription, subscription,
tenantId, tenantId: subscription.tenant!,
treeItem: rootChild treeItem: rootChild
} }
}; };

View File

@@ -206,10 +206,10 @@ class FlatAccountTreeNodeLoader {
const resourceProviderIds = await this._resourceService.listResourceProviderIds(); const resourceProviderIds = await this._resourceService.listResourceProviderIds();
for (const subscription of subscriptions) { for (const subscription of subscriptions) {
for (const providerId of resourceProviderIds) { for (const providerId of resourceProviderIds) {
const resourceTypes = await this._resourceService.getRootChildren(providerId, this._account, subscription, subscription.tenant!); const resourceTypes = await this._resourceService.getRootChildren(providerId, this._account, subscription);
for (const resourceType of resourceTypes) { for (const resourceType of resourceTypes) {
const resources = await this._resourceService.getChildren(providerId, resourceType.resourceNode, true); const resources = await this._resourceService.getChildren(providerId, resourceType.resourceNode, true);
if (resources.length > 0) { if (resources?.length > 0) {
this._nodes.push(...resources.map(dr => new AzureResourceResourceTreeNode(dr, this._accountNode, this.appContext))); this._nodes.push(...resources.map(dr => new AzureResourceResourceTreeNode(dr, this._accountNode, this.appContext)));
this._nodes = this.nodes.sort((a, b) => { this._nodes = this.nodes.sort((a, b) => {
return a.getNodeInfo().label.localeCompare(b.getNodeInfo().label); return a.getNodeInfo().label.localeCompare(b.getNodeInfo().label);

View File

@@ -126,7 +126,7 @@ class ResourceLoader {
for (const tenant of account.properties.tenants) { for (const tenant of account.properties.tenants) {
for (const subscription of await this.subscriptionService.getSubscriptions(account, [tenant.id])) { for (const subscription of await this.subscriptionService.getSubscriptions(account, [tenant.id])) {
for (const providerId of await this.resourceService.listResourceProviderIds()) { for (const providerId of await this.resourceService.listResourceProviderIds()) {
for (const group of await this.resourceService.getRootChildren(providerId, account, subscription, subscription.tenant!)) { for (const group of await this.resourceService.getRootChildren(providerId, account, subscription)) {
const children = await this.resourceService.getChildren(providerId, group.resourceNode); const children = await this.resourceService.getChildren(providerId, group.resourceNode);
let groupNode: AzureResourceResourceTreeNode | undefined = this.resourceGroups.get(group.resourceProviderId); let groupNode: AzureResourceResourceTreeNode | undefined = this.resourceGroups.get(group.resourceProviderId);
if (groupNode) { if (groupNode) {

View File

@@ -42,7 +42,7 @@ export class AzureResourceSubscriptionTreeNode extends AzureResourceContainerTre
const children: IAzureResourceNodeWithProviderId[] = []; const children: IAzureResourceNodeWithProviderId[] = [];
for (const resourceProviderId of await resourceService.listResourceProviderIds()) { for (const resourceProviderId of await resourceService.listResourceProviderIds()) {
children.push(...await resourceService.getRootChildren(resourceProviderId, this.account, this.subscription, this.tenantId)); children.push(...await resourceService.getRootChildren(resourceProviderId, this.account, this.subscription));
} }
if (children.length === 0) { if (children.length === 0) {

View File

@@ -164,6 +164,9 @@ describe('AzureResourceDatabaseTreeDataProvider.getChildren', function (): void
should(child.treeItem.label).equal(`${database.name} (${database.serverName})`); should(child.treeItem.label).equal(`${database.name} (${database.serverName})`);
should(child.treeItem.collapsibleState).equal(vscode.TreeItemCollapsibleState.Collapsed); should(child.treeItem.collapsibleState).equal(vscode.TreeItemCollapsibleState.Collapsed);
should(child.treeItem.contextValue).equal(AzureResourceItemType.database); should(child.treeItem.contextValue).equal(AzureResourceItemType.database);
// Authentication type should be empty string by default to support setting 'Sql: Default Authentication Type'.
should(child.treeItem.payload!.authenticationType).equal('');
} }
}); });
}); });

View File

@@ -163,6 +163,9 @@ describe('AzureResourceDatabaseServerTreeDataProvider.getChildren', function ():
should(child.treeItem.label).equal(databaseServer.name); should(child.treeItem.label).equal(databaseServer.name);
should(child.treeItem.collapsibleState).equal(vscode.TreeItemCollapsibleState.Collapsed); should(child.treeItem.collapsibleState).equal(vscode.TreeItemCollapsibleState.Collapsed);
should(child.treeItem.contextValue).equal(AzureResourceItemType.databaseServer); should(child.treeItem.contextValue).equal(AzureResourceItemType.databaseServer);
// Authentication type should be empty string by default to support setting 'Sql: Default Authentication Type'.
should(child.treeItem.payload!.authenticationType).equal('');
} }
}); });
}); });

View File

@@ -106,7 +106,7 @@ describe('AzureResourceService.getRootChildren', function (): void {
}); });
it('Should be correct when provider id is correct.', async function (): Promise<void> { it('Should be correct when provider id is correct.', async function (): Promise<void> {
const children = await resourceService.getRootChildren(mockResourceProvider1.object.providerId, mockAccount, mockSubscription, mockTenantId); const children = await resourceService.getRootChildren(mockResourceProvider1.object.providerId, mockAccount, mockSubscription);
should(children).Array(); should(children).Array();
}); });
@@ -114,7 +114,7 @@ describe('AzureResourceService.getRootChildren', function (): void {
it('Should throw exceptions when provider id is incorrect.', async function (): Promise<void> { it('Should throw exceptions when provider id is incorrect.', async function (): Promise<void> {
const providerId = 'non_existent_provider_id'; const providerId = 'non_existent_provider_id';
try { try {
await resourceService.getRootChildren(providerId, mockAccount, mockSubscription, mockTenantId); await resourceService.getRootChildren(providerId, mockAccount, mockSubscription);
} catch (error) { } catch (error) {
should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`); should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`);
return; return;
@@ -147,7 +147,7 @@ describe('AzureResourceService.getChildren', function (): void {
it('Should throw exceptions when provider id is incorrect.', async function (): Promise<void> { it('Should throw exceptions when provider id is incorrect.', async function (): Promise<void> {
const providerId = 'non_existent_provider_id'; const providerId = 'non_existent_provider_id';
try { try {
await resourceService.getRootChildren(providerId, mockAccount, mockSubscription, mockTenantId); await resourceService.getRootChildren(providerId, mockAccount, mockSubscription);
} catch (error) { } catch (error) {
should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`); should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`);
return; return;
@@ -180,7 +180,7 @@ describe('AzureResourceService.getTreeItem', function (): void {
it('Should throw exceptions when provider id is incorrect.', async function (): Promise<void> { it('Should throw exceptions when provider id is incorrect.', async function (): Promise<void> {
const providerId = 'non_existent_provider_id'; const providerId = 'non_existent_provider_id';
try { try {
await resourceService.getRootChildren(providerId, mockAccount, mockSubscription, mockTenantId); await resourceService.getRootChildren(providerId, mockAccount, mockSubscription);
} catch (error) { } catch (error) {
should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`); should(error.message).equal(`Azure resource provider doesn't exist. Id: ${providerId}`);
return; return;

View File

@@ -686,17 +686,15 @@ export class ConnectionWidget extends lifecycle.Disposable {
} }
private onAzureTenantSelected(tenantIndex: number): void { private onAzureTenantSelected(tenantIndex: number): void {
this._azureTenantId = undefined;
let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value);
if (account && account.properties.tenants) { if (account && account.properties.tenants) {
let tenant = account.properties.tenants[tenantIndex]; let tenant = account.properties.tenants[tenantIndex];
if (tenant) { if (tenant) {
this._azureTenantId = tenant.id;
this._callbacks.onAzureTenantSelection(tenant.id); this._callbacks.onAzureTenantSelection(tenant.id);
} }
else { else {
// This should ideally never ever happen! // This should ideally never ever happen!
this._logService.error(`onAzureTenantSelected : Could not find tenant with ID ${this._azureTenantId} for account ${account.displayInfo.displayName}`); this._logService.error(`onAzureTenantSelected : Tenant list not found as expected, missing tenant on index ${tenantIndex}`);
} }
} }
} }