diff --git a/extensions/azurecore/src/account-provider/auths/azureAuth.ts b/extensions/azurecore/src/account-provider/auths/azureAuth.ts index 5859fe32f3..dc4c5b5d13 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuth.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuth.ts @@ -26,8 +26,9 @@ import * as qs from 'qs'; import { AzureAuthError } from './azureAuthError'; import { AccountInfo, AuthenticationResult, InteractionRequiredAuthError, PublicClientApplication } from '@azure/msal-node'; import { HttpClient } from './httpClient'; -import { getProxyEnabledHttpClient } from '../../utils'; +import { getProxyEnabledHttpClient, getTenantIgnoreList, updateTenantIgnoreList } from '../../utils'; import { errorToPromptFailedResult } from './networkUtils'; +import { MsalCachePluginProvider } from '../utils/msalCachePlugin'; const localize = nls.loadMessageBundle(); export abstract class AzureAuth implements vscode.Disposable { @@ -46,6 +47,7 @@ export abstract class AzureAuth implements vscode.Disposable { constructor( protected readonly metadata: AzureAccountProviderMetadata, protected readonly tokenCache: SimpleTokenCache, + protected readonly msalCacheProvider: MsalCachePluginProvider, protected readonly context: vscode.ExtensionContext, protected clientApplication: PublicClientApplication, protected readonly uriEventEmitter: vscode.EventEmitter, @@ -119,7 +121,8 @@ export abstract class AzureAuth implements vscode.Disposable { const token: Token = { token: result.response.accessToken, key: result.response.account.homeAccountId, - tokenType: result.response.tokenType + tokenType: result.response.tokenType, + expiresOn: result.response.expiresOn!.getTime() / 1000 }; const tokenClaims = result.response.idTokenClaims; const account = await this.hydrateAccount(token, tokenClaims); @@ -228,7 +231,7 @@ export abstract class AzureAuth implements vscode.Disposable { const cachedTokens = await this.getSavedTokenAdal(tenant, resource, account.key); // Let's check to see if we can just use the cached tokens to return to the user - if (cachedTokens?.accessToken) { + if (cachedTokens) { let expiry = Number(cachedTokens.expiresOn); if (Number.isNaN(expiry)) { Logger.error('Expiration time was not defined. This is expected on first launch'); @@ -315,7 +318,8 @@ export abstract class AzureAuth implements vscode.Disposable { * (i.e. expired token, wrong scope, etc.), sends a request for a new token using the refresh token * @param accountId * @param azureResource - * @returns The authentication result, including the access token + * @returns The authentication result, including the access token. + * This function returns 'null' instead of 'undefined' by design as the same is returned by MSAL APIs in the flow (e.g. acquireTokenSilent). */ public async getTokenMsal(accountId: string, azureResource: azdata.AzureResource, tenantId: string): Promise { const resource = this.resources.find(s => s.azureResourceId === azureResource); @@ -325,6 +329,12 @@ export abstract class AzureAuth implements vscode.Disposable { return null; } + // The user wants to ignore this tenant. + if (getTenantIgnoreList().includes(tenantId)) { + Logger.info(`Tenant ${tenantId} found in the ignore list, authentication will not be attempted.`); + return null; + } + // Resource endpoint must end with '/' to form a valid scope for MSAL token request. const endpoint = resource.endpoint.endsWith('/') ? resource.endpoint : resource.endpoint + '/'; @@ -341,14 +351,14 @@ export abstract class AzureAuth implements vscode.Disposable { } // construct request - // forceRefresh needs to be set true here in order to fetch the correct token for non-full tenants, due to this issue + // forceRefresh needs to be set true here in order to fetch the correct token, due to this issue // https://github.com/AzureAD/microsoft-authentication-library-for-js/issues/3687 + // Even for full tenants, access token is often received expired - force refresh is necessary when token expires. const tokenRequest = { account: account, authority: `${this.loginEndpointUrl}${tenantId}`, scopes: newScope, - // Force Refresh when tenant is NOT full tenant or organizational id that this account belongs to. - forceRefresh: tenantId !== account.tenantId + forceRefresh: true }; try { return await this.clientApplication.acquireTokenSilent(tokenRequest); @@ -463,7 +473,7 @@ export abstract class AzureAuth implements vscode.Disposable { public async getTenantsMsal(token: string): Promise { const tenantUri = url.resolve(this.metadata.settings.armResource.endpoint, 'tenants?api-version=2019-11-01'); try { - Logger.verbose('Fetching tenants with uri {0}', tenantUri); + Logger.verbose(`Fetching tenants with uri: ${tenantUri}`); let tenantList: string[] = []; const tenantResponse = await this.httpClient.sendGetRequestAsync(tenantUri, { @@ -512,7 +522,7 @@ export abstract class AzureAuth implements vscode.Disposable { public async getTenantsAdal(token: AccessToken): Promise { const tenantUri = url.resolve(this.metadata.settings.armResource.endpoint, 'tenants?api-version=2019-11-01'); try { - Logger.verbose('Fetching tenants with URI: {0}', tenantUri); + Logger.verbose(`Fetching tenants with uri: ${tenantUri}`); let tenantList: string[] = []; const tenantResponse = await this.makeGetRequest(tenantUri, token.token); if (tenantResponse.status !== 200) { @@ -644,24 +654,14 @@ export abstract class AzureAuth implements vscode.Disposable { if (!tenant.displayName && !tenant.id) { throw new Error('Tenant did not have display name or id'); } - - const getTenantConfigurationSet = (): Set => { - const configuration = vscode.workspace.getConfiguration(Constants.AzureTenantConfigSection); - let values: string[] = configuration.get('filter') ?? []; - return new Set(values); - }; + const tenantIgnoreList = getTenantIgnoreList(); // The user wants to ignore this tenant. - if (getTenantConfigurationSet().has(tenant.id)) { + if (tenantIgnoreList.includes(tenant.id)) { Logger.info(`Tenant ${tenant.id} found in the ignore list, authentication will not be attempted.`); return false; } - const updateTenantConfigurationSet = async (set: Set): Promise => { - const configuration = vscode.workspace.getConfiguration('azure.tenant.config'); - await configuration.update('filter', Array.from(set), vscode.ConfigurationTarget.Global); - }; - interface ConsentMessageItem extends vscode.MessageItem { booleanResult: boolean; action?: (tenantId: string) => Promise; @@ -682,9 +682,8 @@ export abstract class AzureAuth implements vscode.Disposable { title: localize('azurecore.consentDialog.ignore', "Ignore Tenant"), booleanResult: false, action: async (tenantId: string) => { - let set = getTenantConfigurationSet(); - set.add(tenantId); - await updateTenantConfigurationSet(set); + tenantIgnoreList.push(tenantId); + await updateTenantIgnoreList(tenantIgnoreList); } }; @@ -828,6 +827,7 @@ export abstract class AzureAuth implements vscode.Disposable { } public async deleteAllCacheMsal(): Promise { this.clientApplication.clearCache(); + await this.msalCacheProvider.clearLocalCache(); } public async deleteAllCacheAdal(): Promise { const results = await this.tokenCache.findCredentials(''); @@ -852,17 +852,18 @@ export abstract class AzureAuth implements vscode.Disposable { } } - public async deleteAccountCacheMsal(account: azdata.AccountKey): Promise { + private async deleteAccountCacheMsal(accountKey: azdata.AccountKey): Promise { const tokenCache = this.clientApplication.getTokenCache(); - let msalAccount: AccountInfo | null = await this.getAccountFromMsalCache(account.accountId); + let msalAccount: AccountInfo | null = await this.getAccountFromMsalCache(accountKey.accountId); if (!msalAccount) { - Logger.error(`MSAL: Unable to find account ${account.accountId} for removal`); - throw Error(`Unable to find account ${account.accountId}`); + Logger.error(`MSAL: Unable to find account ${accountKey.accountId} for removal`); + throw Error(`Unable to find account ${accountKey.accountId}`); } await tokenCache.removeAccount(msalAccount); + await this.msalCacheProvider.clearAccountFromLocalCache(accountKey.accountId); } - public async deleteAccountCacheAdal(account: azdata.AccountKey): Promise { + private async deleteAccountCacheAdal(account: azdata.AccountKey): Promise { const results = await this.tokenCache.findCredentials(account.accountId); if (!results) { Logger.error('ADAL: Unable to find account for removal'); @@ -927,12 +928,22 @@ export interface Token extends AccountKey { /** * Access token expiry timestamp */ - expiresOn?: number; + expiresOn: number | undefined; /** * TokenType */ tokenType: string; + + /** + * Associated Tenant Id + */ + tenantId?: string; + + /** + * Resource to which token belongs to. + */ + resource?: azdata.AzureResource; } export interface TokenClaims { // https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens diff --git a/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts b/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts index 7b8b753c03..cb30e0b645 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts @@ -19,6 +19,7 @@ import * as http from 'http'; import * as qs from 'qs'; import { promises as fs } from 'fs'; import { PublicClientApplication, CryptoProvider, AuthorizationUrlRequest, AuthorizationCodeRequest, AuthenticationResult } from '@azure/msal-node'; +import { MsalCachePluginProvider } from '../utils/msalCachePlugin'; const localize = nls.loadMessageBundle(); @@ -43,12 +44,13 @@ export class AzureAuthCodeGrant extends AzureAuth { constructor( metadata: AzureAccountProviderMetadata, tokenCache: SimpleTokenCache, + msalCacheProvider: MsalCachePluginProvider, context: vscode.ExtensionContext, uriEventEmitter: vscode.EventEmitter, clientApplication: PublicClientApplication, authLibrary: string ) { - super(metadata, tokenCache, context, clientApplication, uriEventEmitter, AzureAuthType.AuthCodeGrant, AzureAuthCodeGrant.USER_FRIENDLY_NAME, authLibrary); + super(metadata, tokenCache, msalCacheProvider, context, clientApplication, uriEventEmitter, AzureAuthType.AuthCodeGrant, AzureAuthCodeGrant.USER_FRIENDLY_NAME, authLibrary); this.cryptoProvider = new CryptoProvider(); this.pkceCodes = { nonce: '', diff --git a/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts b/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts index 177e1ccae7..16ab9555c7 100644 --- a/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts +++ b/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts @@ -23,6 +23,7 @@ import { Deferred } from '../interfaces'; import { AuthenticationResult, DeviceCodeRequest, PublicClientApplication } from '@azure/msal-node'; import { SimpleTokenCache } from '../utils/simpleTokenCache'; import { Logger } from '../../utils/Logger'; +import { MsalCachePluginProvider } from '../utils/msalCachePlugin'; const localize = nls.loadMessageBundle(); @@ -49,12 +50,13 @@ export class AzureDeviceCode extends AzureAuth { constructor( metadata: AzureAccountProviderMetadata, tokenCache: SimpleTokenCache, + msalCacheProvider: MsalCachePluginProvider, context: vscode.ExtensionContext, uriEventEmitter: vscode.EventEmitter, clientApplication: PublicClientApplication, authLibrary: string ) { - super(metadata, tokenCache, context, clientApplication, uriEventEmitter, AzureAuthType.DeviceCode, AzureDeviceCode.USER_FRIENDLY_NAME, authLibrary); + super(metadata, tokenCache, msalCacheProvider, context, clientApplication, uriEventEmitter, AzureAuthType.DeviceCode, AzureDeviceCode.USER_FRIENDLY_NAME, authLibrary); this.pageTitle = localize('addAccount', "Add {0} account", this.metadata.displayName); } diff --git a/extensions/azurecore/src/account-provider/azureAccountProvider.ts b/extensions/azurecore/src/account-provider/azureAccountProvider.ts index a4fda4c2ab..1bd82942a6 100644 --- a/extensions/azurecore/src/account-provider/azureAccountProvider.ts +++ b/extensions/azurecore/src/account-provider/azureAccountProvider.ts @@ -21,6 +21,7 @@ import { AzureAuthCodeGrant } from './auths/azureAuthCodeGrant'; import { AzureDeviceCode } from './auths/azureDeviceCode'; import { filterAccounts } from '../azureResource/utils'; import * as Constants from '../constants'; +import { MsalCachePluginProvider } from './utils/msalCachePlugin'; const localize = nls.loadMessageBundle(); @@ -35,6 +36,7 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp tokenCache: SimpleTokenCache, context: vscode.ExtensionContext, clientApplication: PublicClientApplication, + private readonly msalCacheProvider: MsalCachePluginProvider, uriEventHandler: vscode.EventEmitter, private readonly authLibrary: string, private readonly forceDeviceCode: boolean = false @@ -71,10 +73,10 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp const deviceCodeMethod: boolean = configuration.get(Constants.AuthType.DeviceCode, false); if (codeGrantMethod === true && !this.forceDeviceCode) { - this.authMappings.set(AzureAuthType.AuthCodeGrant, new AzureAuthCodeGrant(metadata, tokenCache, context, uriEventHandler, this.clientApplication, this.authLibrary)); + this.authMappings.set(AzureAuthType.AuthCodeGrant, new AzureAuthCodeGrant(metadata, tokenCache, this.msalCacheProvider, context, uriEventHandler, this.clientApplication, this.authLibrary)); } if (deviceCodeMethod === true || this.forceDeviceCode) { - this.authMappings.set(AzureAuthType.DeviceCode, new AzureDeviceCode(metadata, tokenCache, context, uriEventHandler, this.clientApplication, this.authLibrary)); + this.authMappings.set(AzureAuthType.DeviceCode, new AzureDeviceCode(metadata, tokenCache, this.msalCacheProvider, context, uriEventHandler, this.clientApplication, this.authLibrary)); } if (codeGrantMethod === false && deviceCodeMethod === false && !this.forceDeviceCode) { console.error('No authentication methods selected'); @@ -146,6 +148,16 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp if (azureAuth) { Logger.piiSanitized(`Getting account security token for ${JSON.stringify(account.key)} (tenant ${tenantId}). Auth Method = ${azureAuth.userFriendlyName}`, [], []); if (this.authLibrary === Constants.AuthLibrary.MSAL) { + try { + // Fetch cached token from local cache if token is available and valid. + let accessToken = await this.msalCacheProvider.getTokenFromLocalCache(account.key.accountId, tenantId, resource); + if (this.isValidToken(accessToken)) { + return accessToken; + } // else fallback to fetching a new token. + } catch (e) { + // Log any error and move on to fetching fresh access token. + Logger.info(`Could not fetch access token from cache: ${e}, fetching new access token instead.`); + } tenantId = tenantId || account.properties.owningTenant.id; let authResult = await azureAuth.getTokenMsal(account.key.accountId, resource, tenantId); if (this.isAuthenticationResult(authResult) && authResult.account && authResult.account.idTokenClaims) { @@ -153,8 +165,15 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp key: authResult.account.homeAccountId, token: authResult.accessToken, tokenType: authResult.tokenType, - expiresOn: authResult.account.idTokenClaims.exp + expiresOn: authResult.account.idTokenClaims.exp!, + tenantId: tenantId, + resource: resource }; + try { + await this.msalCacheProvider.writeTokenToLocalCache(token); + } catch (e) { + Logger.error(`Could not save access token to local cache: ${e}, this might cause throttling of AAD requests.`); + } return token; } else { Logger.error(`MSAL: getToken call failed`); @@ -172,7 +191,6 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp account.isStale = true; Logger.error(`_getAccountSecurityToken: Authentication method not found for account ${account.displayInfo.displayName}`); throw Error('Failed to get authentication method, please remove and re-add the account'); - } } @@ -192,6 +210,16 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp } } + /** + * Validates if access token is still valid by checking it's expiration time has a threshold of atleast 2 mins. + * @param accessToken Access token to be validated + * @returns True if access token is valid. + */ + private isValidToken(accessToken: Token | undefined): boolean { + const currentTime = new Date().getTime() / 1000; + return (accessToken !== undefined && accessToken.expiresOn !== undefined + && Number(accessToken.expiresOn) - currentTime > 2 * 60); // threshold = 2 mins + } private async _getSecurityToken(account: AzureAccount, resource: azdata.AzureResource): Promise { void vscode.window.showInformationMessage(localize('azure.deprecatedGetSecurityToken', "A call was made to azdata.accounts.getSecurityToken, this method is deprecated and will be removed in future releases. Please use getAccountSecurityToken instead.")); @@ -254,7 +282,7 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp return this.prompt(); } - clear(accountKey: azdata.AccountKey): Thenable { + async clear(accountKey: azdata.AccountKey): Promise { return this._clear(accountKey); } diff --git a/extensions/azurecore/src/account-provider/azureAccountProviderService.ts b/extensions/azurecore/src/account-provider/azureAccountProviderService.ts index a71a0673c1..f052723a5b 100644 --- a/extensions/azurecore/src/account-provider/azureAccountProviderService.ts +++ b/extensions/azurecore/src/account-provider/azureAccountProviderService.ts @@ -198,7 +198,8 @@ export class AzureAccountProviderService implements vscode.Disposable { this.clientApplication = new PublicClientApplication(msalConfiguration); let accountProvider = new AzureAccountProvider(provider.metadata as AzureAccountProviderMetadata, - simpleTokenCache, this._context, this.clientApplication, this._uriEventHandler, this._authLibrary, isSaw); + simpleTokenCache, this._context, this.clientApplication, this._cachePluginProvider, + this._uriEventHandler, this._authLibrary, isSaw); this._accountProviders[provider.metadata.id] = accountProvider; this._accountDisposals[provider.metadata.id] = azdata.accounts.registerAccountProvider(provider.metadata, accountProvider); } catch (e) { diff --git a/extensions/azurecore/src/account-provider/interfaces.ts b/extensions/azurecore/src/account-provider/interfaces.ts index b62a5c6a39..3ec215429f 100644 --- a/extensions/azurecore/src/account-provider/interfaces.ts +++ b/extensions/azurecore/src/account-provider/interfaces.ts @@ -26,37 +26,6 @@ export interface Subscription { displayName: string } -/** - * Token returned from a request for an access token - */ -export interface AzureAccountSecurityToken { - /** - * Access token, itself - */ - token: string; - - /** - * Date that the token expires on - */ - expiresOn: Date | string; - - /** - * Name of the resource the token is good for (ie, management.core.windows.net) - */ - resource: string; - - /** - * Type of the token (pretty much always 'Bearer') - */ - tokenType: string; -} - -/** - * Azure account security token maps a tenant ID to the information returned from a request to get - * an access token. The list of tenants correspond to the tenants in the account properties. - */ -export type AzureAccountSecurityTokenCollection = { [tenantId: string]: AzureAccountSecurityToken }; - export interface Deferred { resolve: (result: T | Promise) => void; reject: (reason: E) => void; diff --git a/extensions/azurecore/src/account-provider/utils/fileDatabase.ts b/extensions/azurecore/src/account-provider/utils/fileDatabase.ts index 0422b3a4f1..e9fae2e412 100644 --- a/extensions/azurecore/src/account-provider/utils/fileDatabase.ts +++ b/extensions/azurecore/src/account-provider/utils/fileDatabase.ts @@ -6,7 +6,7 @@ import { promises as fs, constants as fsConstants } from 'fs'; import { Logger } from '../../utils/Logger'; -export type ReadWriteHook = (contents: string) => Promise; +export type ReadWriteHook = (contents: string, resetOnError?: boolean) => Promise; const noOpHook: ReadWriteHook = async (contents): Promise => { return contents; }; @@ -97,7 +97,7 @@ export class FileDatabase { try { await fs.access(this.dbPath, fsConstants.R_OK | fsConstants.R_OK); fileContents = await fs.readFile(this.dbPath, { encoding: 'utf8' }); - fileContents = await this.readHook(fileContents); + fileContents = await this.readHook(fileContents, true); } catch (ex) { Logger.error(`Error occurred when initializing File Database from file system cache, ADAL cache will be reset: ${ex}`); await this.createFile(); diff --git a/extensions/azurecore/src/account-provider/utils/fileEncryptionHelper.ts b/extensions/azurecore/src/account-provider/utils/fileEncryptionHelper.ts index 10774db56a..e588ca3f89 100644 --- a/extensions/azurecore/src/account-provider/utils/fileEncryptionHelper.ts +++ b/extensions/azurecore/src/account-provider/utils/fileEncryptionHelper.ts @@ -79,7 +79,7 @@ export class FileEncryptionHelper { return cipherText; } - fileOpener = async (content: string): Promise => { + fileOpener = async (content: string, resetOnError?: boolean): Promise => { try { if (!this._keyBuffer || !this._ivBuffer) { await this.init(); @@ -97,13 +97,15 @@ export class FileEncryptionHelper { return `${decipherIv.update(plaintext, this._binaryEncoding, 'utf8')}${decipherIv.final('utf8')}`; } catch (ex) { Logger.error(`FileEncryptionHelper: Error occurred when decrypting data, IV/KEY will be reset: ${ex}`); - // Reset IV/Keys if crypto cannot encrypt/decrypt data. - // This could be a possible case of corruption of expected iv/key combination - await this.deleteEncryptionKey(this._ivCredId); - await this.deleteEncryptionKey(this._keyCredId); - this._ivBuffer = undefined; - this._keyBuffer = undefined; - await this.init(); + if (resetOnError) { + // Reset IV/Keys if crypto cannot encrypt/decrypt data. + // This could be a possible case of corruption of expected iv/key combination + await this.deleteEncryptionKey(this._ivCredId); + await this.deleteEncryptionKey(this._keyCredId); + this._ivBuffer = undefined; + this._keyBuffer = undefined; + await this.init(); + } // Throw error so cache file can be reset to empty. throw new Error(`Decryption failed with error: ${ex}`); } diff --git a/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts b/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts index 54c1ad9943..46005ce7c4 100644 --- a/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts +++ b/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts @@ -10,28 +10,49 @@ import * as lockFile from 'lockfile'; import * as path from 'path'; import * as azdata from 'azdata'; import * as vscode from 'vscode'; -import { AccountsClearTokenCacheCommand, AuthLibrary } from '../../constants'; +import { AccountsClearTokenCacheCommand, AuthLibrary, LocalCacheSuffix, LockFileSuffix } from '../../constants'; import { Logger } from '../../utils/Logger'; import { FileEncryptionHelper } from './fileEncryptionHelper'; import { CacheEncryptionKeys } from 'azurecore'; +import { Token } from '../auths/azureAuth'; + +interface CacheConfiguration { + name: string, + cacheFilePath: string, + lockFilePath: string, + lockTaken: boolean +} + +interface LocalAccountCache { + tokens: Token[]; +} export class MsalCachePluginProvider { constructor( private readonly _serviceName: string, - private readonly _msalFilePath: string, + msalFilePath: string, private readonly _credentialService: azdata.CredentialProvider, private readonly _onEncryptionKeysUpdated: vscode.EventEmitter ) { - this._msalFilePath = path.join(this._msalFilePath, this._serviceName); this._fileEncryptionHelper = new FileEncryptionHelper(AuthLibrary.MSAL, this._credentialService, this._serviceName, this._onEncryptionKeysUpdated); + this._msalCacheConfiguration = { + name: 'MSAL', + cacheFilePath: path.join(msalFilePath, this._serviceName), + lockFilePath: path.join(msalFilePath, this._serviceName) + LockFileSuffix, + lockTaken: false + } + this._localCacheConfiguration = { + name: 'Local', + cacheFilePath: path.join(msalFilePath, this._serviceName) + LocalCacheSuffix, + lockFilePath: path.join(msalFilePath, this._serviceName) + LocalCacheSuffix + LockFileSuffix, + lockTaken: false + } } - private _lockTaken: boolean = false; private _fileEncryptionHelper: FileEncryptionHelper; - - private getLockfilePath(): string { - return this._msalFilePath + '.lockfile'; - } + private _msalCacheConfiguration: CacheConfiguration; + private _localCacheConfiguration: CacheConfiguration; + private _emptyLocalCache: LocalAccountCache = { tokens: [] }; public async init(): Promise { await this._fileEncryptionHelper.init(); @@ -42,52 +63,22 @@ export class MsalCachePluginProvider { } public getCachePlugin(): ICachePlugin { - const lockFilePath = this.getLockfilePath(); const beforeCacheAccess = async (cacheContext: TokenCacheContext): Promise => { - await this.waitAndLock(lockFilePath); try { - const cache = await fsPromises.readFile(this._msalFilePath, { encoding: 'utf8' }); - const decryptedData = await this._fileEncryptionHelper.fileOpener(cache!); - try { - cacheContext.tokenCache.deserialize(decryptedData); - } catch (e) { - // Handle deserialization error in cache file in case file gets corrupted. - // Clearing cache here will ensure account is marked stale so re-authentication can be triggered. - Logger.verbose(`MsalCachePlugin: Error occurred when trying to read cache file, file will be deleted: ${e.message}`); - await fsPromises.unlink(this._msalFilePath); - } - Logger.verbose(`MsalCachePlugin: Token read from cache successfully.`); + const decryptedData = await this.readCache(this._msalCacheConfiguration); + cacheContext.tokenCache.deserialize(decryptedData); } catch (e) { - if (e.code === 'ENOENT') { - // File doesn't exist, log and continue - Logger.verbose(`MsalCachePlugin: Cache file not found on disk: ${e.code}`); - } - else { - Logger.error(`MsalCachePlugin: Failed to read from cache file: ${e}`); - Logger.verbose(`MsalCachePlugin: Error occurred when trying to read cache file, file will be deleted: ${e.message}`); - await fsPromises.unlink(this._msalFilePath); - } - } finally { - lockFile.unlockSync(lockFilePath); - this._lockTaken = false; + // Handle deserialization error in cache file in case file gets corrupted. + // Clearing cache here will ensure account is marked stale so re-authentication can be triggered. + Logger.verbose(`MsalCachePlugin: Error occurred when trying to read cache file, file will be deleted: ${e.message}`); + await fsPromises.unlink(this._msalCacheConfiguration.cacheFilePath); } } const afterCacheAccess = async (cacheContext: TokenCacheContext): Promise => { if (cacheContext.cacheHasChanged) { - await this.waitAndLock(lockFilePath); - try { - const data = cacheContext.tokenCache.serialize(); - const encryptedData = await this._fileEncryptionHelper.fileSaver(data!); - await fsPromises.writeFile(this._msalFilePath, encryptedData, { encoding: 'utf8' }); - Logger.verbose(`MsalCachePlugin: Token written to cache successfully.`); - } catch (e) { - Logger.error(`MsalCachePlugin: Failed to write to cache file. ${e}`); - throw e; - } finally { - lockFile.unlockSync(lockFilePath); - this._lockTaken = false; - } + const data = cacheContext.tokenCache.serialize(); + await this.writeCache(data, this._msalCacheConfiguration); } }; @@ -102,14 +93,134 @@ export class MsalCachePluginProvider { }; } - private async waitAndLock(lockFilePath: string): Promise { + /** + * Fetches access token from local cache, before accessing MSAL Cache. + * @param accountId Account Id for token owner. + * @param tenantId Tenant Id to which token belongs to. + * @param resource Resource Id to which token belongs to. + * @returns Access Token. + */ + public async getTokenFromLocalCache(accountId: string, tenantId: string, resource: azdata.AzureResource): Promise { + let cache = JSON.parse(await this.readCache(this._localCacheConfiguration)) as LocalAccountCache; + let token = cache?.tokens?.find(token => ( + token.key === accountId && + token.tenantId === tenantId && + token.resource === resource + )); + return token; + } + + /** + * Updates local cache with newly fetched access token to prevent throttling of AAD requests. + * @param token Access token to be written to cache file. + */ + public async writeTokenToLocalCache(token: Token): Promise { + let updateCount = 0; + let cache: LocalAccountCache; + cache = JSON.parse(await this.readCache(this._localCacheConfiguration)) as LocalAccountCache; + if (cache?.tokens) { + cache.tokens.forEach(t => { + if (t.key === token.key && t.tenantId === token.tenantId && t.resource === token.resource + ) { + // Update token + t = token; + updateCount++; + } + }); + } else { + // Initialize token cache + cache = this._emptyLocalCache; + } + + if (updateCount === 0) { + // No tokens were updated, add new token. + cache.tokens.push(token); + updateCount = 1; + } + + if (updateCount === 1) { + await this.writeCache(JSON.stringify(cache), this._localCacheConfiguration); + } + else { + Logger.info(`Found multiple tokens in local cache, cache will be reset.`); + // Reset cache as we don't expect multiple tokens to be stored for same combination. + await this.writeCache(JSON.stringify(this._emptyLocalCache), this._localCacheConfiguration); + } + } + + /** + * Removes associated tokens for account, to be called when account is deleted. + * @param accountId Account ID + */ + public async clearAccountFromLocalCache(accountId: string): Promise { + let cache = JSON.parse(await this.readCache(this._localCacheConfiguration)) as LocalAccountCache; + let tokenIndices: number[] = []; + if (cache?.tokens) { + cache.tokens.forEach((t, i) => { + if (t.key === accountId) { + tokenIndices.push(i); + } + }); + } + tokenIndices.forEach(i => { + cache.tokens.splice(i); + }) + Logger.info(`Local Cache cleared for account, ${tokenIndices.length} tokens were cleared.`); + } + + /** + * Clears local access token cache. + */ + public async clearLocalCache(): Promise { + await this.writeCache(JSON.stringify({ tokens: [] }), this._localCacheConfiguration); + } + + //#region Private helper methods + private async writeCache(fileContents: string, config: CacheConfiguration): Promise { + config.lockTaken = await this.waitAndLock(config.lockFilePath, config.lockTaken); + try { + const encryptedCache = await this._fileEncryptionHelper.fileSaver(fileContents); + await fsPromises.writeFile(config.cacheFilePath, encryptedCache, { encoding: 'utf8' }); + } catch (e) { + Logger.error(`MsalCachePlugin: Failed to write to '${config.name}' cache file: ${e}`); + throw e; + } finally { + lockFile.unlockSync(config.lockFilePath); + config.lockTaken = false; + } + } + + private async readCache(config: CacheConfiguration): Promise { + config.lockTaken = await this.waitAndLock(config.lockFilePath, config.lockTaken); + try { + const cache = await fsPromises.readFile(config.cacheFilePath, { encoding: 'utf8' }); + const decryptedData = await this._fileEncryptionHelper.fileOpener(cache!, true); + return decryptedData; + } catch (e) { + if (e.code === 'ENOENT') { + // File doesn't exist, log and continue + Logger.verbose(`MsalCachePlugin: Cache file for '${config.name}' cache not found on disk: ${e.code}`); + } + else { + Logger.error(`MsalCachePlugin: Failed to read from cache file: ${e}`); + Logger.verbose(`MsalCachePlugin: Error occurred when trying to read cache file, file will be deleted: ${e.message}`); + await fsPromises.unlink(config.cacheFilePath); + } + return '{}'; // Return empty json string if cache not read. + } finally { + lockFile.unlockSync(config.lockFilePath); + config.lockTaken = false; + } + } + + private async waitAndLock(lockFilePath: string, lockTaken: boolean): Promise { // Make 500 retry attempts with 100ms wait time between each attempt to allow enough time for the lock to be released. const retries = 500; const retryWait = 100; // We cannot rely on lockfile.lockSync() to clear stale lockfile, // so we check if the lockfile exists and if it does, calling unlockSync() will clear it. - if (lockFile.checkSync(lockFilePath) && !this._lockTaken) { + if (lockFile.checkSync(lockFilePath) && !lockTaken) { lockFile.unlockSync(lockFilePath); Logger.verbose(`MsalCachePlugin: Stale lockfile found and has been removed.`); } @@ -120,7 +231,7 @@ export class MsalCachePluginProvider { // Use lockfile.lockSync() to ensure only one process is accessing the cache at a time. // lockfile.lock() does not wait for async callback promise to resolve. lockFile.lockSync(lockFilePath); - this._lockTaken = true; + lockTaken = true; break; } catch (e) { if (retryAttempt === retries) { @@ -132,5 +243,7 @@ export class MsalCachePluginProvider { await new Promise(resolve => setTimeout(resolve, retryWait)); } } + return lockTaken; } + //#endregion } diff --git a/extensions/azurecore/src/azurecore.d.ts b/extensions/azurecore/src/azurecore.d.ts index dff4de72f0..14bd2c4f16 100644 --- a/extensions/azurecore/src/azurecore.d.ts +++ b/extensions/azurecore/src/azurecore.d.ts @@ -35,8 +35,11 @@ declare module 'azurecore' { /** * Auth type of azure used to authenticate this account. */ - azureAuthType?: AzureAuthType + azureAuthType?: AzureAuthType; + /** + * Provider settings for account. + */ providerSettings: AzureAccountProviderMetadata; /** @@ -53,7 +56,6 @@ declare module 'azurecore' { * A list of tenants (aka directories) that the account belongs to */ tenants: Tenant[]; - } export const enum AzureAuthType { diff --git a/extensions/azurecore/src/constants.ts b/extensions/azurecore/src/constants.ts index 55407edc94..8d64ba4ee2 100644 --- a/extensions/azurecore/src/constants.ts +++ b/extensions/azurecore/src/constants.ts @@ -37,6 +37,8 @@ export const TenantSection = 'tenant'; export const AzureTenantConfigSection = AzureSection + '.' + TenantSection + '.' + ConfigSection; +export const Filter = 'filter'; + export const NoSystemKeyChainSection = 'noSystemKeychain'; export const oldMsalCacheFileName = 'azureTokenCacheMsal-azure_publicCloud'; @@ -72,6 +74,10 @@ export const MSALCacheName = 'accessTokenCache'; export const DefaultAuthLibrary = 'MSAL'; +export const LocalCacheSuffix = '.local'; + +export const LockFileSuffix = '.lockfile'; + export enum BuiltInCommands { SetContext = 'setContext' } diff --git a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts index 42133981a5..a836e65993 100644 --- a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts +++ b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts @@ -19,7 +19,8 @@ let azureAuthCodeGrant: TypeMoq.IMock; const mockToken: Token = { key: 'someUniqueId', token: 'test_token', - tokenType: 'Bearer' + tokenType: 'Bearer', + expiresOn: new Date().getTime() / 1000 + (60 * 60) // 1 hour from now. }; let mockAccessToken: AccessToken; let mockRefreshToken: RefreshToken; diff --git a/extensions/azurecore/src/utils.ts b/extensions/azurecore/src/utils.ts index 3a6a332848..3e111b56f0 100644 --- a/extensions/azurecore/src/utils.ts +++ b/extensions/azurecore/src/utils.ts @@ -136,10 +136,29 @@ export function getResourceTypeDisplayName(type: string): string { } return type; } + function getHttpConfiguration(): vscode.WorkspaceConfiguration { return vscode.workspace.getConfiguration(constants.httpConfigSectionName); } +/** + * Gets tenants to be ignored. + * @returns Tenants configured in ignore list + */ +export function getTenantIgnoreList(): string[] { + const configuration = vscode.workspace.getConfiguration(constants.AzureTenantConfigSection); + return configuration.get(constants.Filter) ?? []; +} + +/** + * Updates tenant ignore list in global settings. + * @param tenantIgnoreList Tenants to be configured in ignore list + */ +export async function updateTenantIgnoreList(tenantIgnoreList: string[]): Promise { + const configuration = vscode.workspace.getConfiguration(constants.AzureTenantConfigSection); + await configuration.update(constants.Filter, tenantIgnoreList, vscode.ConfigurationTarget.Global); +} + export function getResourceTypeIcon(appContext: AppContext, type: string): string { switch (type) { case azureResource.AzureResourceType.sqlServer: diff --git a/src/sql/platform/accounts/test/common/testAccountManagementService.ts b/src/sql/platform/accounts/test/common/testAccountManagementService.ts index 88f055d65b..2710323d50 100644 --- a/src/sql/platform/accounts/test/common/testAccountManagementService.ts +++ b/src/sql/platform/accounts/test/common/testAccountManagementService.ts @@ -107,6 +107,7 @@ export class AccountProviderStub implements azdata.AccountProvider { getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> { return Promise.resolve(undefined!); } + initialize(storedAccounts: azdata.Account[]): Thenable { return Promise.resolve(storedAccounts); } diff --git a/src/sql/workbench/api/common/extHostAccountManagement.ts b/src/sql/workbench/api/common/extHostAccountManagement.ts index 33c6527a34..cb3d0dae7c 100644 --- a/src/sql/workbench/api/common/extHostAccountManagement.ts +++ b/src/sql/workbench/api/common/extHostAccountManagement.ts @@ -63,11 +63,9 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape { this._proxy.$accountUpdated(updatedAccount); } - - public $getAllAccounts(): Thenable { - return this.getAllProvidersAndAccounts().then(providersAndAccounts => { - return providersAndAccounts.map(providerAndAccount => providerAndAccount.account); - }); + public async $getAllAccounts(): Promise { + let providersAndAccounts = await this.getAllProvidersAndAccounts(); + return providersAndAccounts.map(providerAndAccount => providerAndAccount.account); } private async getAllProvidersAndAccounts(): Promise { @@ -94,32 +92,30 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape { return resultProviderAndAccounts; } - public override $getSecurityToken(account: azdata.Account, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{}> { - return this.getAllProvidersAndAccounts().then(providerAndAccounts => { - const providerAndAccount = providerAndAccounts.find(providerAndAccount => providerAndAccount.account.key.accountId === account.key.accountId); - if (providerAndAccount) { - return providerAndAccount.provider.getSecurityToken(account, resource); - } - throw new Error(`Account ${account.key.accountId} not found.`); - }); + public override async $getSecurityToken(account: azdata.Account, resource: azdata.AzureResource = AzureResource.ResourceManagement): Promise<{}> { + let providerAndAccounts = await this.getAllProvidersAndAccounts(); + const providerAndAccount = providerAndAccounts.find(providerAndAccount => providerAndAccount.account.key.accountId === account.key.accountId); + if (providerAndAccount) { + return providerAndAccount.provider.getSecurityToken(account, resource); + } + throw new Error(`Account ${account.key.accountId} not found.`); } - public override $getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable { - return this.getAllProvidersAndAccounts().then(providerAndAccounts => { - const providerAndAccount = providerAndAccounts.find(providerAndAccount => providerAndAccount.account.key.accountId === account.key.accountId); - if (providerAndAccount) { - return providerAndAccount.provider.getAccountSecurityToken(account, tenant, resource); - } - throw new Error(`Account ${account.key.accountId} not found.`); - }); + public override async $getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource = AzureResource.ResourceManagement): Promise { + let providerAndAccounts = await this.getAllProvidersAndAccounts(); + const providerAndAccount = providerAndAccounts.find(providerAndAccount => providerAndAccount.account.key.accountId === account.key.accountId); + if (providerAndAccount) { + return await providerAndAccount.provider.getAccountSecurityToken(account, tenant, resource); + } + throw Error(`Account ${account.key.accountId} not found.`); } public get onDidChangeAccounts(): Event { return this._onDidChangeAccounts.event; } - public override $accountsChanged(handle: number, accounts: azdata.Account[]): Thenable { - return Promise.resolve(this._onDidChangeAccounts.fire({ accounts: accounts })); + public override async $accountsChanged(handle: number, accounts: azdata.Account[]): Promise { + return this._onDidChangeAccounts.fire({ accounts: accounts }); } public $registerAccountProvider(providerMetadata: azdata.AccountProviderMetadata, provider: azdata.AccountProvider): Disposable { diff --git a/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts b/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts index cddb288b03..4ccef56f59 100644 --- a/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts +++ b/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts @@ -361,8 +361,8 @@ export class AccountManagementService implements IAccountManagementService { * @return Promise to return the security token */ public getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Promise { - return this.doWithProvider(account.key.providerId, provider => { - return Promise.resolve(provider.provider.getAccountSecurityToken(account, tenant, resource)); + return this.doWithProvider(account.key.providerId, async provider => { + return await provider.provider.getAccountSecurityToken(account, tenant, resource); }); }