From 112de46723e16423c251ac1b0543729ade744d05 Mon Sep 17 00:00:00 2001 From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com> Date: Thu, 15 Dec 2022 16:38:45 -0800 Subject: [PATCH] Fix ADAL to MSAL transition of connections and account list (#21425) --- .../src/account-provider/auths/azureAuth.ts | 47 +++++++++---------- .../account-provider/azureAccountProvider.ts | 45 +++++++++++------- .../browser/connectionManagementService.ts | 4 +- .../connection/browser/connectionWidget.ts | 22 +++++++-- 4 files changed, 72 insertions(+), 46 deletions(-) diff --git a/extensions/azurecore/src/account-provider/auths/azureAuth.ts b/extensions/azurecore/src/account-provider/auths/azureAuth.ts index 9aec664aee..2956351f63 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuth.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuth.ts @@ -318,23 +318,12 @@ export abstract class AzureAuth implements vscode.Disposable { * @returns The authentication result, including the access token */ public async getTokenMsal(accountId: string, azureResource: azdata.AzureResource, tenantId: string): Promise { - const cache = this.clientApplication.getTokenCache(); - if (!cache) { - Logger.error('Error: Could not fetch token cache.'); - return null; - } const resource = this.resources.find(s => s.azureResourceId === azureResource); if (!resource) { Logger.error(`Error: Could not fetch the azure resource ${azureResource} `); return null; } - let account: AccountInfo | null; - // if the accountId is a home ID, it will include a "." character - if (accountId.includes(".")) { - account = await cache.getAccountByHomeId(accountId); - } else { - account = await cache.getAccountByLocalId(accountId); - } + let account: AccountInfo | null = await this.getAccountFromMsalCache(accountId); if (!account) { Logger.error('Error: Could not fetch account when acquiring token'); return null; @@ -374,6 +363,23 @@ export abstract class AzureAuth implements vscode.Disposable { } } + public async getAccountFromMsalCache(accountId: string): Promise { + const cache = this.clientApplication.getTokenCache(); + if (!cache) { + Logger.error('Error: Could not fetch token cache.'); + return null; + } + + let account: AccountInfo | null = null; + // if the accountId is a home ID, it will include a "." character + if (accountId.includes(".")) { + account = await cache.getAccountByHomeId(accountId); + } else { + account = await cache.getAccountByLocalId(accountId); + } + return account; + } + public async getTokenAdal(tenant: Tenant, resource: Resource, postData: AuthorizationCodePostData | TokenPostData | RefreshTokenPostData): Promise { Logger.verbose('Fetching token for tenant {0}', tenant.id); const tokenUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/token`; @@ -817,26 +823,19 @@ export abstract class AzureAuth implements vscode.Disposable { // remove account based on authLibrary field, accounts added before this field was present will default to // ADAL method of account removal if (account.authLibrary === Constants.AuthLibrary.MSAL) { - return this.deleteAccountCacheMsal(account); + return await this.deleteAccountCacheMsal(account); } else { // fallback to ADAL by default - return this.deleteAccountCacheAdal(account); + return await this.deleteAccountCacheAdal(account); } } catch (ex) { - const msg = localize('azure.cacheErrrorRemove', "Error when removing your account from the cache."); - void vscode.window.showErrorMessage(msg); - Logger.error('Error when removing tokens.', ex); + // We need not prompt user for error if token could not be removed from cache. + Logger.error('Error when removing token from cache: ', ex); } } public async deleteAccountCacheMsal(account: azdata.AccountKey): Promise { const tokenCache = this.clientApplication.getTokenCache(); - let msalAccount: AccountInfo | null; - // if the accountId is a home ID, it will include a "." character - if (account.accountId.includes(".")) { - msalAccount = await tokenCache.getAccountByHomeId(account.accountId); - } else { - msalAccount = await tokenCache.getAccountByLocalId(account.accountId); - } + let msalAccount: AccountInfo | null = await this.getAccountFromMsalCache(account.accountId); if (!msalAccount) { Logger.error(`MSAL: Unable to find account ${account.accountId} for removal`); throw Error(`Unable to find account ${account.accountId}`); diff --git a/extensions/azurecore/src/account-provider/azureAccountProvider.ts b/extensions/azurecore/src/account-provider/azureAccountProvider.ts index a454c50b90..87f8786bb7 100644 --- a/extensions/azurecore/src/account-provider/azureAccountProvider.ts +++ b/extensions/azurecore/src/account-provider/azureAccountProvider.ts @@ -115,7 +115,13 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp } else { account.isStale = false; if (this.authLibrary === Constants.AuthLibrary.MSAL) { + // Check MSAL Cache before adding account, to mark it as stale if it is not present in cache + const accountInCache = await azureAuth.getAccountFromMsalCache(account.key.accountId); + if (!accountInCache) { + account.isStale = true; + } accounts.push(account); + } else { // fallback to ADAL as default accounts.push(await azureAuth.refreshAccessAdal(account)); } @@ -137,23 +143,30 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp private async _getAccountSecurityToken(account: AzureAccount, tenantId: string, resource: azdata.AzureResource): Promise { await this.initCompletePromise; const azureAuth = this.getAuthMethod(account); - Logger.pii(`Getting account security token for ${JSON.stringify(account.key)} (tenant ${tenantId}). Auth Method = ${azureAuth.userFriendlyName}`, [], []); - if (this.authLibrary === Constants.AuthLibrary.MSAL) { - let authResult = await azureAuth?.getTokenMsal(account.key.accountId, resource, tenantId); - if (!authResult || !authResult.account || !authResult.account.idTokenClaims) { - Logger.error(`MSAL: getToken call failed`); - throw Error('Failed to get token'); - } else { - const token: Token = { - key: authResult.account.homeAccountId, - token: authResult.accessToken, - tokenType: authResult.tokenType, - expiresOn: authResult.account.idTokenClaims.exp - }; - return token; + if (azureAuth) { + Logger.pii(`Getting account security token for ${JSON.stringify(account.key)} (tenant ${tenantId}). Auth Method = ${azureAuth.userFriendlyName}`, [], []); + if (this.authLibrary === Constants.AuthLibrary.MSAL) { + let authResult = await azureAuth.getTokenMsal(account.key.accountId, resource, tenantId); + if (!authResult || !authResult.account || !authResult.account.idTokenClaims) { + Logger.error(`MSAL: getToken call failed`); + throw Error('Failed to get token'); + } else { + const token: Token = { + key: authResult.account.homeAccountId, + token: authResult.accessToken, + tokenType: authResult.tokenType, + expiresOn: authResult.account.idTokenClaims.exp + }; + return token; + } + } else { // fallback to ADAL as default + return azureAuth.getAccountSecurityTokenAdal(account, tenantId, resource); } - } else { // fallback to ADAL as default - return azureAuth?.getAccountSecurityTokenAdal(account, tenantId, resource); + } else { + account.isStale = true; + Logger.error(`_getAccountSecurityToken: Authentication method not found for account ${account.displayInfo.displayName}`); + throw Error('Failed to get authentication method, please remove and re-add the account'); + } } diff --git a/src/sql/workbench/services/connection/browser/connectionManagementService.ts b/src/sql/workbench/services/connection/browser/connectionManagementService.ts index ebd5540781..b912b2ab62 100644 --- a/src/sql/workbench/services/connection/browser/connectionManagementService.ts +++ b/src/sql/workbench/services/connection/browser/connectionManagementService.ts @@ -899,7 +899,9 @@ export class ConnectionManagementService extends Disposable implements IConnecti const azureAccounts = accounts.filter(a => a.key.providerId.startsWith('azure')); if (azureAccounts && azureAccounts.length > 0) { let accountId = (connection.authenticationType === Constants.AuthenticationType.AzureMFA || connection.authenticationType === Constants.AuthenticationType.AzureMFAAndUser) ? connection.azureAccount : connection.userName; - let account = azureAccounts.find(account => account.key.accountId === accountId); + // For backwards compatibility with ADAL, we need to check if the account ID matches with tenant Id or just the account ID + // The OR case can be removed once we no longer support ADAL + let account = azureAccounts.find(account => account.key.accountId === accountId || account.key.accountId.split('.')[0] === accountId); if (account) { this._logService.debug(`Getting security token for Azure account ${account.key.accountId}`); if (account.isStale) { diff --git a/src/sql/workbench/services/connection/browser/connectionWidget.ts b/src/sql/workbench/services/connection/browser/connectionWidget.ts index aa21da6354..1bf1737f9d 100644 --- a/src/sql/workbench/services/connection/browser/connectionWidget.ts +++ b/src/sql/workbench/services/connection/browser/connectionWidget.ts @@ -627,7 +627,10 @@ export class ConnectionWidget extends lifecycle.Disposable { } private updateRefreshCredentialsLink(): void { - let chosenAccount = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); + // For backwards compatibility with ADAL, we need to check if the account ID matches with tenant Id or just the account ID + // The OR case can be removed once we no longer support ADAL + let chosenAccount = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value + || account.key.accountId.split('.')[0] === this._azureAccountDropdown.value); if (chosenAccount && chosenAccount.isStale) { this._tableContainer.classList.remove('hide-refresh-link'); } else { @@ -648,7 +651,10 @@ export class ConnectionWidget extends lifecycle.Disposable { await this.fillInAzureAccountOptions(); // If a new account was added find it and select it, otherwise select the first account - let newAccount = this._azureAccountList.find(option => !oldAccountIds.some(oldId => oldId === option.key.accountId)); + // For backwards compatibility with ADAL, we need to check if the account ID matches with tenant Id or just the account ID + // The OR case can be removed once we no longer support ADAL + let newAccount = this._azureAccountList.find(option => !oldAccountIds.some(oldId => oldId === option.key.accountId + || oldId.split('.')[0] === option.key.accountId)); if (newAccount) { this._azureAccountDropdown.selectWithOptionName(newAccount.key.accountId); } else { @@ -660,7 +666,10 @@ export class ConnectionWidget extends lifecycle.Disposable { // Display the tenant select box if needed const hideTenantsClassName = 'hide-azure-tenants'; - let selectedAccount = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); + // For backwards compatibility with ADAL, we need to check if the account ID matches with tenant Id or just the account ID + // The OR case can be removed once we no longer support ADAL + let selectedAccount = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value + || account.key.accountId.split('.')[0] === this._azureAccountDropdown.value); if (selectedAccount && selectedAccount.properties.tenants && selectedAccount.properties.tenants.length > 1) { // There are multiple tenants available so let the user select one let options = selectedAccount.properties.tenants.map(tenant => tenant.displayName); @@ -834,9 +843,12 @@ export class ConnectionWidget extends lifecycle.Disposable { let tenantId = connectionInfo.azureTenantId; let accountName = (this.authType === AuthenticationType.AzureMFA) ? connectionInfo.azureAccount : connectionInfo.userName; - this._azureAccountDropdown.selectWithOptionName(this.getModelValue(accountName)); + // For backwards compatibility with ADAL, we need to check if the account ID matches with tenant Id or just the account ID + // The OR case can be removed once we no longer support ADAL + let account = this._azureAccountList.find(account => account.key.accountId === this.getModelValue(accountName) + || account.key.accountId.split('.')[0] === this.getModelValue(accountName)); + this._azureAccountDropdown.selectWithOptionName(account.key.accountId); await this.onAzureAccountSelected(); - let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); if (account && account.properties.tenants && account.properties.tenants.length > 1) { let tenant = account.properties.tenants.find(tenant => tenant.id === tenantId); if (tenant) {