diff --git a/extensions/machine-learning/src/common/apiWrapper.ts b/extensions/machine-learning/src/common/apiWrapper.ts index 85a2efb1bc..a05f9f8f78 100644 --- a/extensions/machine-learning/src/common/apiWrapper.ts +++ b/extensions/machine-learning/src/common/apiWrapper.ts @@ -99,8 +99,8 @@ export class ApiWrapper { return azdata.accounts.getAllAccounts(); } - public getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{ [key: string]: any }> { - return azdata.accounts.getSecurityToken(account, resource); + public getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string, tokenType?: string } | undefined> { + return azdata.accounts.getAccountSecurityToken(account, tenant, resource); } public showQuickPick(items: T[] | Thenable, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable { diff --git a/extensions/machine-learning/src/modelManagement/azureModelRegistryService.ts b/extensions/machine-learning/src/modelManagement/azureModelRegistryService.ts index 0c1c46290e..d34b04237d 100644 --- a/extensions/machine-learning/src/modelManagement/azureModelRegistryService.ts +++ b/extensions/machine-learning/src/modelManagement/azureModelRegistryService.ts @@ -303,15 +303,12 @@ export class AzureModelRegistryService { if (this._amlClient) { return this._amlClient; } else { - const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement); + const tokens: { token: string, tokenType?: string } | undefined = await this._apiWrapper.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); let token: string = ''; let tokenType: string | undefined = undefined; - if (tokens && tenant.id in tokens) { - const tokenForId = tokens[tenant.id]; - if (tokenForId) { - token = tokenForId.token; - tokenType = tokenForId.tokenType; - } + if (tokens) { + token = tokens.token; + tokenType = tokens.tokenType; } const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options); if (apiVersion) {