diff --git a/extensions/azurecore/src/azureResource/utils.ts b/extensions/azurecore/src/azureResource/utils.ts index db5411b153..4b8da53e55 100644 --- a/extensions/azurecore/src/azureResource/utils.ts +++ b/extensions/azurecore/src/azureResource/utils.ts @@ -140,7 +140,11 @@ export async function getResourceGroups(appContext: AppContext, account?: azdata return result; } -export async function runResourceQuery(appContext: AppContext, account: azdata.Account, subscription: azureResource.AzureResourceSubscription, ignoreErrors: boolean = false, query: string) { +export async function runResourceQuery( + account: azdata.Account, + subscriptions: azureResource.AzureResourceSubscription[], + ignoreErrors: boolean = false, + query: string): Promise> { const result: ResourceQueryResult = { resources: [], errors: [] }; if (!account?.properties?.tenants || !isArray(account.properties.tenants)) { const error = new Error(localize('azure.accounts.runResourceQuery.errors.invalidAccount', "Invalid account")); @@ -151,51 +155,74 @@ export async function runResourceQuery { + if (!subscription.tenant) { + const error = new Error(localize('azure.accounts.runResourceQuery.errors.noTenantSpecifiedForSubscription', "Invalid tenant for subscription")); + if (!ignoreErrors) { + throw error; + } + result.errors.push(error); } - result.errors.push(error); + }); + if (result.errors.length > 0) { return result; } - const tokenResponse = await azdata.accounts.getAccountSecurityToken(account, subscription.tenant, azdata.AzureResource.ResourceManagement); - const token = tokenResponse.token; - const tokenType = tokenResponse.tokenType; - const credential = new TokenCredentials(token, tokenType); - - const resourceClient = new ResourceGraphClient(credential, { baseUri: account.properties.providerSettings.settings.armResource.endpoint }); - - const allResources: T[] = []; - let totalProcessed = 0; - - const doQuery = async (skipToken?: string) => { - const response = await resourceClient.resources({ - subscriptions: [subscription.id], - query, - options: { - resultFormat: 'objectArray', - skipToken: skipToken - } - }); - const resources: T[] = response.data; - totalProcessed += resources.length; - allResources.push(...resources); - if (response.skipToken && totalProcessed < response.totalRecords) { - await doQuery(response.skipToken); + // We need to get a different security token for each tenant to query the resources for the subscriptions on + // that tenant + for (let i = 0; i < account.properties.tenants.length; ++i) { + const tenant = account.properties.tenants[i]; + const tenantSubscriptions = subscriptions.filter(subscription => subscription.tenant === tenant.id); + if (tenantSubscriptions.length < 1) { + // We may not have all subscriptions or the tenant might not have any subscriptions - just ignore these ones + continue; } - }; - try { - await doQuery(); - } catch (err) { - console.error(err); - const error = new Error(localize('azure.accounts.runResourceQuery.errors.invalidQuery', "Invalid query")); - result.errors.push(error); - } - result.resources = allResources; - return result; + let resourceClient: ResourceGraphClient; + try { + const tokenResponse = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); + const token = tokenResponse.token; + const tokenType = tokenResponse.tokenType; + const credential = new TokenCredentials(token, tokenType); + + resourceClient = new ResourceGraphClient(credential, { baseUri: account.properties.providerSettings.settings.armResource.endpoint }); + } catch (err) { + console.error(err); + const error = new Error(localize('azure.accounts.runResourceQuery.errors.unableToFetchToken', "Unable to get token for tenant {0}", tenant.id)); + result.errors.push(error); + continue; + } + + const allResources: T[] = []; + let totalProcessed = 0; + + const doQuery = async (skipToken?: string) => { + const response = await resourceClient.resources({ + subscriptions: tenantSubscriptions.map(subscription => subscription.id), + query, + options: { + resultFormat: 'objectArray', + skipToken: skipToken + } + }); + const resources: T[] = response.data; + totalProcessed += resources.length; + allResources.push(...resources); + if (response.skipToken && totalProcessed < response.totalRecords) { + await doQuery(response.skipToken); + } + }; + try { + await doQuery(); + } catch (err) { + console.error(err); + const error = new Error(localize('azure.accounts.runResourceQuery.errors.invalidQuery', "Invalid query")); + result.errors.push(error); + } + result.resources.push(...allResources); + } + return result; } export async function getSubscriptions(appContext: AppContext, account?: azdata.Account, ignoreErrors: boolean = false): Promise { diff --git a/extensions/azurecore/src/azurecore.d.ts b/extensions/azurecore/src/azurecore.d.ts index 9de9321ee0..bf929bb7d6 100644 --- a/extensions/azurecore/src/azurecore.d.ts +++ b/extensions/azurecore/src/azurecore.d.ts @@ -72,7 +72,7 @@ declare module 'azurecore' { getRegionDisplayName(region?: string): string; provideResources(): azureResource.IAzureResourceProvider[]; - runGraphQuery(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, ignoreErrors: boolean, query: string): Promise>; + runGraphQuery(account: azdata.Account, subscriptions: azureResource.AzureResourceSubscription[], ignoreErrors: boolean, query: string): Promise>; } export type GetSubscriptionsResult = { subscriptions: azureResource.AzureResourceSubscription[], errors: Error[] }; diff --git a/extensions/azurecore/src/extension.ts b/extensions/azurecore/src/extension.ts index 4156da0c4e..406985ea01 100644 --- a/extensions/azurecore/src/extension.ts +++ b/extensions/azurecore/src/extension.ts @@ -110,10 +110,10 @@ export async function activate(context: vscode.ExtensionContext): Promise(account: azdata.Account, - subscription: azureResource.AzureResourceSubscription, + subscriptions: azureResource.AzureResourceSubscription[], ignoreErrors: boolean, query: string): Promise> { - return azureResourceUtils.runResourceQuery(appContext, account, subscription, ignoreErrors, query); + return azureResourceUtils.runResourceQuery(account, subscriptions, ignoreErrors, query); } }; } diff --git a/extensions/machine-learning/src/test/stubs.ts b/extensions/machine-learning/src/test/stubs.ts index aa7e897402..cd5b8623b6 100644 --- a/extensions/machine-learning/src/test/stubs.ts +++ b/extensions/machine-learning/src/test/stubs.ts @@ -8,7 +8,7 @@ import * as azurecore from 'azurecore'; import { azureResource } from 'azureResource'; export class AzurecoreApiStub implements azurecore.IExtension { - runGraphQuery(_account: azdata.Account, _subscription: azureResource.AzureResourceSubscription, _ignoreErrors: boolean, _query: string): Promise> { + runGraphQuery(_account: azdata.Account, _subscriptions: azureResource.AzureResourceSubscription[], _ignoreErrors: boolean, _query: string): Promise> { throw new Error('Method not implemented.'); } getSubscriptions(_account?: azdata.Account | undefined, _ignoreErrors?: boolean | undefined): Thenable { diff --git a/extensions/sql-migration/src/api/azure.ts b/extensions/sql-migration/src/api/azure.ts index 167f7e5ecd..105e53e392 100644 --- a/extensions/sql-migration/src/api/azure.ts +++ b/extensions/sql-migration/src/api/azure.ts @@ -39,7 +39,7 @@ export type SqlManagedInstance = AzureProduct; export async function getAvailableManagedInstanceProducts(account: azdata.Account, subscription: Subscription): Promise { const api = await getAzureCoreAPI(); - const result = await api.runGraphQuery(account, subscription, false, 'where type == "microsoft.sql/managedinstances"'); + const result = await api.runGraphQuery(account, [subscription], false, 'where type == "microsoft.sql/managedinstances"'); return result.resources; } @@ -47,7 +47,7 @@ export type SqlServer = AzureProduct; export async function getAvailableSqlServers(account: azdata.Account, subscription: Subscription): Promise { const api = await getAzureCoreAPI(); - const result = await api.runGraphQuery(account, subscription, false, 'where type == "microsoft.sql/servers"'); + const result = await api.runGraphQuery(account, [subscription], false, 'where type == "microsoft.sql/servers"'); return result.resources; } @@ -55,6 +55,6 @@ export type SqlVMServer = AzureProduct; export async function getAvailableSqlVMs(account: azdata.Account, subscription: Subscription): Promise { const api = await getAzureCoreAPI(); - const result = await api.runGraphQuery(account, subscription, false, 'where type == "microsoft.compute/virtualmachines" and properties.storageProfile.imageReference.publisher == "microsoftsqlserver"'); + const result = await api.runGraphQuery(account, [subscription], false, 'where type == "microsoft.compute/virtualmachines" and properties.storageProfile.imageReference.publisher == "microsoftsqlserver"'); return result.resources; }