diff --git a/extensions/azurecore/src/azureDataGridProvider.ts b/extensions/azurecore/src/azureDataGridProvider.ts index 4f46a2fc97..42a706b06b 100644 --- a/extensions/azurecore/src/azureDataGridProvider.ts +++ b/extensions/azurecore/src/azureDataGridProvider.ts @@ -7,7 +7,6 @@ import * as azdata from 'azdata'; import { AppContext } from './appContext'; import { AzureResourceServiceNames } from './azureResource/constants'; import { IAzureResourceSubscriptionService } from './azureResource/interfaces'; -import { TokenCredentials } from '@azure/ms-rest-js'; import { azureResource } from 'azureResource'; import * as azureResourceUtils from './azureResource/utils'; import * as constants from './constants'; @@ -36,12 +35,8 @@ export class AzureDataGridProvider implements azdata.DataGridProvider { await Promise.all(accounts.map(async (account) => { await Promise.all(account.properties.tenants.map(async (tenant: { id: string; }) => { try { - const tokenResponse = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); - const token = tokenResponse.token; - const tokenType = tokenResponse.tokenType; - const credential = new TokenCredentials(token, tokenType); const subscriptionService = this._appContext.getService(AzureResourceServiceNames.subscriptionService); - const subscriptions = await subscriptionService.getSubscriptions(account, credential, tenant.id); + const subscriptions = await subscriptionService.getSubscriptions(account, [tenant.id]); try { const newItems = (await azureResourceUtils.runResourceQuery(account, subscriptions, true, `where ${typesClause}`)).resources .map(item => { diff --git a/extensions/azurecore/src/azureResource/commands.ts b/extensions/azurecore/src/azureResource/commands.ts index 622773ab48..72943b11f4 100644 --- a/extensions/azurecore/src/azureResource/commands.ts +++ b/extensions/azurecore/src/azureResource/commands.ts @@ -111,7 +111,7 @@ export function registerAzureResourceCommands(appContext: AppContext, azureViewT let subscriptions: azureResource.AzureResourceSubscription[] = []; if (subscriptions.length === 0) { try { - subscriptions = await subscriptionService.getAllSubscriptions(account); + subscriptions = await subscriptionService.getSubscriptions(account); } catch (error) { account.isStale = true; vscode.window.showErrorMessage(AzureResourceErrorMessageUtil.getErrorMessage(error)); diff --git a/extensions/azurecore/src/azureResource/interfaces.ts b/extensions/azurecore/src/azureResource/interfaces.ts index 4c58026dce..ea5a10a12c 100644 --- a/extensions/azurecore/src/azureResource/interfaces.ts +++ b/extensions/azurecore/src/azureResource/interfaces.ts @@ -11,8 +11,14 @@ import { azureResource } from 'azureResource'; import { AzureAccount, Tenant } from 'azurecore'; export interface IAzureResourceSubscriptionService { - getSubscriptions(account: Account, credential: msRest.ServiceClientCredentials, tenantId: string): Promise; - getAllSubscriptions(account: Account): Promise; + /** + * Gets subscriptions for the given account. Any errors that occur while fetching the subscriptions for each tenant + * will be displayed to the user, but this function will only throw an error if it's unable to fetch any subscriptions. + * @param account The account to get the subscriptions for + * @param tenants The list of tenants to get subscriptions for - if undefined then subscriptions for all tenants will be retrieved + * @returns The list of all subscriptions on this account that were able to be retrieved + */ + getSubscriptions(account: Account, tenants?: string[] | undefined): Promise; } export interface IAzureResourceSubscriptionFilterService { diff --git a/extensions/azurecore/src/azureResource/services/subscriptionService.ts b/extensions/azurecore/src/azureResource/services/subscriptionService.ts index 060e09b6d2..eb5c78b623 100644 --- a/extensions/azurecore/src/azureResource/services/subscriptionService.ts +++ b/extensions/azurecore/src/azureResource/services/subscriptionService.ts @@ -14,47 +14,33 @@ import { AzureSubscriptionError } from '../errors'; import { AzureResourceErrorMessageUtil } from '../utils'; import * as nls from 'vscode-nls'; -import { AzureAccount } from 'azurecore'; const localize = nls.loadMessageBundle(); export class AzureResourceSubscriptionService implements IAzureResourceSubscriptionService { /** - * Gets all of the subscriptions for the specified account using the specified credential. This assumes that the credential passed is for - * the specified tenant - which the subscriptions returned will be associated with. - * @param account The account to get the subscriptions for - * @param credential The credential to use for querying the subscriptions - * @param tenantId The ID of the tenant these subscriptions are for - * @returns The list of all subscriptions on this account for the specified tenant - */ - public async getSubscriptions(account: azdata.Account, credential: any, tenantId: string): Promise { - const subscriptions: azureResource.AzureResourceSubscription[] = []; - - const subClient = new SubscriptionClient(credential, { baseUri: account.properties.providerSettings.settings.armResource.endpoint }); - const subs = await subClient.subscriptions.list(); - subs.forEach((sub) => subscriptions.push({ - id: sub.subscriptionId, - name: sub.displayName, - tenant: tenantId - })); - - return subscriptions; - } - - /** - * Gets all subscriptions for all tenants of the given account. Any errors that occur while fetching the subscriptions for each tenant + * Gets subscriptions for the given account. Any errors that occur while fetching the subscriptions for each tenant * will be displayed to the user, but this function will only throw an error if it's unable to fetch any subscriptions. * @param account The account to get the subscriptions for + * @param tenants The list of tenants to get subscriptions for - if undefined then subscriptions for all tenants will be retrieved * @returns The list of all subscriptions on this account that were able to be retrieved */ - public async getAllSubscriptions(account: AzureAccount): Promise { + public async getSubscriptions(account: azdata.Account, tenants?: string[]): Promise { const subscriptions: azureResource.AzureResourceSubscription[] = []; let gotSubscriptions = false; const errors: Error[] = []; - for (const tenant of account.properties.tenants) { + for (const tenant of tenants ?? account.properties.tenants) { try { const token = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); - subscriptions.push(...(await this.getSubscriptions(account, new TokenCredentials(token.token, token.tokenType), tenant.id) || [])); + const subClient = new SubscriptionClient(new TokenCredentials(token.token, token.tokenType), { baseUri: account.properties.providerSettings.settings.armResource.endpoint }); + const newSubs = await subClient.subscriptions.list(); + subscriptions.push(...newSubs.map(newSub => { + return { + id: newSub.subscriptionId, + name: newSub.displayName, + tenant: tenant.id + }; + })); gotSubscriptions = true; } catch (error) { const errorMsg = localize('azure.resource.tenantSubscriptionsError', "Failed to get subscriptions for account {0} (tenant '{1}'). {2}", account.key.accountId, tenant.id, AzureResourceErrorMessageUtil.getErrorMessage(error)); diff --git a/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts b/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts index 7484d6a6df..c07bb9a2ac 100644 --- a/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts +++ b/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts @@ -43,7 +43,7 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode let subscriptions: azureResource.AzureResourceSubscription[] = []; if (this._isClearingCache) { - subscriptions = await this._subscriptionService.getAllSubscriptions(this.account); + subscriptions = await this._subscriptionService.getSubscriptions(this.account); this.updateCache(subscriptions); this._isClearingCache = false; } else { diff --git a/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts b/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts index 7e86e8475e..12ace17f1c 100644 --- a/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts +++ b/extensions/azurecore/src/azureResource/tree/flatAccountTreeNode.ts @@ -118,7 +118,7 @@ async function getSubscriptionInfo(account: AzureAccount, subscriptionService: I total: number, selected: number }> { - let subscriptions = await subscriptionService.getAllSubscriptions(account); + let subscriptions = await subscriptionService.getSubscriptions(account); const total = subscriptions.length; let selected = total; diff --git a/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts b/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts index f59f1fbb5e..cb535cb744 100644 --- a/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts +++ b/extensions/azurecore/src/azureResource/tree/flatTreeProvider.ts @@ -7,7 +7,6 @@ import * as vscode from 'vscode'; import * as azdata from 'azdata'; import { AppContext } from '../../appContext'; import * as nls from 'vscode-nls'; -import { TokenCredentials } from '@azure/ms-rest-js'; const localize = nls.loadMessageBundle(); import { TreeNode } from '../treeNode'; @@ -123,9 +122,7 @@ class ResourceLoader { for (const account of accounts) { for (const tenant of account.properties.tenants) { - const token = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); - - for (const subscription of await this.subscriptionService.getSubscriptions(account, new TokenCredentials(token.token, token.tokenType), tenant.id)) { + for (const subscription of await this.subscriptionService.getSubscriptions(account, [tenant.id])) { for (const providerId of await this.resourceService.listResourceProviderIds()) { for (const group of await this.resourceService.getRootChildren(providerId, account, subscription, subscription.tenant)) { const children = await this.resourceService.getChildren(providerId, group.resourceNode); diff --git a/extensions/azurecore/src/azureResource/utils.ts b/extensions/azurecore/src/azureResource/utils.ts index 0de5010748..715e20e642 100644 --- a/extensions/azurecore/src/azureResource/utils.ts +++ b/extensions/azurecore/src/azureResource/utils.ts @@ -276,11 +276,7 @@ export async function getSubscriptions(appContext: AppContext, account?: azdata. const subscriptionService = appContext.getService(AzureResourceServiceNames.subscriptionService); await Promise.all(account.properties.tenants.map(async (tenant: { id: string; }) => { try { - const response = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); - const token = response.token; - const tokenType = response.tokenType; - - result.subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType), tenant.id)); + result.subscriptions.push(...await subscriptionService.getSubscriptions(account, [tenant.id])); } catch (err) { const error = new Error(localize('azure.accounts.getSubscriptions.queryError', "Error fetching subscriptions for account {0} tenant {1} : {2}", account.displayInfo.displayName, diff --git a/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts b/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts index d5d0e8e88e..c1560bfe33 100644 --- a/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts +++ b/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts @@ -87,6 +87,7 @@ describe('AzureResourceAccountTreeNode.info', function (): void { mockExtensionContext = TypeMoq.Mock.ofType(); mockCacheService = TypeMoq.Mock.ofType(); mockSubscriptionService = TypeMoq.Mock.ofType(); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, undefined)).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService = TypeMoq.Mock.ofType(); mockTreeChangeHandler = TypeMoq.Mock.ofType(); @@ -128,7 +129,7 @@ describe('AzureResourceAccountTreeNode.info', function (): void { }); it('Should be correct when there are subscriptions listed.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getAllSubscriptions(mockAccount)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny())).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(undefined)); const accountTreeNodeLabel = `${mockAccount.displayInfo.displayName} (${mockSubscriptions.length} / ${mockSubscriptions.length} subscriptions)`; @@ -148,7 +149,7 @@ describe('AzureResourceAccountTreeNode.info', function (): void { }); it('Should be correct when there are subscriptions filtered.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getAllSubscriptions(mockAccount)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny())).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(mockFilteredSubscriptions)); const accountTreeNodeLabel = `${mockAccount.displayInfo.displayName} (${mockFilteredSubscriptions.length} / ${mockSubscriptions.length} subscriptions)`; @@ -184,7 +185,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { mockAppContext.registerService(AzureResourceServiceNames.subscriptionService, mockSubscriptionService.object); mockAppContext.registerService(AzureResourceServiceNames.subscriptionFilterService, mockSubscriptionFilterService.object); - sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken)); + sinon.stub(azdata.accounts, 'getAccountSecurityToken').resolves(mockToken); mockCacheService.setup((o) => o.generateKey(TypeMoq.It.isAnyString())).returns(() => generateGuid()); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache = mockSubscriptions); @@ -195,14 +196,14 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should load subscriptions from scratch and update cache when it is clearing cache.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getAllSubscriptions(mockAccount)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny())).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve([])); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); const children = await accountTreeNode.getChildren(); - mockSubscriptionService.verify((o) => o.getAllSubscriptions(mockAccount), TypeMoq.Times.once()); + mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(0)); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockSubscriptionFilterService.verify((o) => o.getSelectedSubscriptions(mockAccount), TypeMoq.Times.once()); @@ -228,7 +229,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should load subscriptions from cache when it is not clearing cache.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getAllSubscriptions(mockAccount)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny())).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(undefined)); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); @@ -237,7 +238,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { const children = await accountTreeNode.getChildren(); - mockSubscriptionService.verify((o) => o.getAllSubscriptions(mockAccount), TypeMoq.Times.once()); + mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); @@ -249,7 +250,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should handle when there is no subscriptions.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getAllSubscriptions(mockAccount)).returns(() => Promise.resolve([])); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny())).returns(() => Promise.resolve([])); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); @@ -265,7 +266,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should honor subscription filtering.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getAllSubscriptions(mockAccount)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny())).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(mockFilteredSubscriptions)); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); @@ -283,7 +284,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should handle errors.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getAllSubscriptions(mockAccount)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny())).returns(() => Promise.resolve(mockSubscriptions)); const mockError = 'Test error'; mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => { throw new Error(mockError); }); @@ -292,7 +293,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { const children = await accountTreeNode.getChildren(); - mockSubscriptionService.verify((o) => o.getAllSubscriptions(mockAccount), TypeMoq.Times.once()); + mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, TypeMoq.It.isAny()), TypeMoq.Times.once()); mockSubscriptionFilterService.verify((o) => o.getSelectedSubscriptions(mockAccount), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once());