From 0bc2a50d78029c3cec1fdaa3df304a3518da6954 Mon Sep 17 00:00:00 2001 From: Hai Cao Date: Thu, 2 Sep 2021 17:24:45 -0700 Subject: [PATCH] Add token expiration handling for AzureMFA auth (#16936) * refresh azure account token if it's expired before sending query/list requests * fix several connection checks && add more logging * fix async UI glitch during token refreshing * cleanup * minor fix * add test for refreshAzureAccountTokenIfNecessary * address comments * comments * comments * comments * error handling --- .../src/account-provider/auths/azureAuth.ts | 9 +++ .../account-provider/auths/azureAuth.test.ts | 4 +- src/sql/azdata.proposed.d.ts | 9 +++ .../connection/common/connectionManagement.ts | 1 + .../connection/common/connectionProfile.ts | 4 + .../common/testConnectionManagementService.ts | 4 + .../contrib/query/browser/queryActions.ts | 4 + .../browser/connectionManagementService.ts | 61 +++++++++++++- .../connectionManagementService.test.ts | 79 ++++++++++++++++++- 9 files changed, 171 insertions(+), 4 deletions(-) diff --git a/extensions/azurecore/src/account-provider/auths/azureAuth.ts b/extensions/azurecore/src/account-provider/auths/azureAuth.ts index b543da4222..0945504d04 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuth.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuth.ts @@ -184,17 +184,20 @@ export abstract class AzureAuth implements vscode.Disposable { const currentTime = new Date().getTime() / 1000; let accessToken = cachedTokens.accessToken; + let expiresOn = Number(cachedTokens.expiresOn); const remainingTime = expiry - currentTime; const maxTolerance = 2 * 60; // two minutes if (remainingTime < maxTolerance) { const result = await this.refreshToken(tenant, resource, cachedTokens.refreshToken); accessToken = result.accessToken; + expiresOn = Number(result.expiresOn); } // Let's just return here. if (accessToken) { return { ...accessToken, + expiresOn: expiresOn, tokenType: 'Bearer' }; } @@ -214,6 +217,7 @@ export abstract class AzureAuth implements vscode.Disposable { if (result.accessToken) { return { ...result.accessToken, + expiresOn: Number(result.expiresOn), tokenType: 'Bearer' }; } @@ -674,6 +678,11 @@ export interface Token extends AccountKey { */ token: string; + /** + * Access token expiry timestamp + */ + expiresOn?: number; + /** * TokenType */ diff --git a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts index 12d727ea23..2123cb2bd3 100644 --- a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts +++ b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts @@ -96,8 +96,8 @@ describe('Azure Authentication', function () { it('token recieved for ossRdbmns resource', async function () { azureAuthCodeGrant.setup(x => x.getTenants(mockToken)).returns(() => { return Promise.resolve([ - mockTenant - ]); + mockTenant + ]); }); azureAuthCodeGrant.setup(x => x.getTokenHelper(mockTenant, provider.settings.ossRdbmsResource, TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { return Promise.resolve({ diff --git a/src/sql/azdata.proposed.d.ts b/src/sql/azdata.proposed.d.ts index 6709eabff6..7ca20e36aa 100644 --- a/src/sql/azdata.proposed.d.ts +++ b/src/sql/azdata.proposed.d.ts @@ -947,4 +947,13 @@ declare module 'azdata' { */ parentTypeName?: string; } + + export namespace accounts { + export interface AccountSecurityToken { + /** + * Access token expiry timestamp + */ + expiresOn?: number + } + } } diff --git a/src/sql/platform/connection/common/connectionManagement.ts b/src/sql/platform/connection/common/connectionManagement.ts index 479810cb56..6b9778926d 100644 --- a/src/sql/platform/connection/common/connectionManagement.ts +++ b/src/sql/platform/connection/common/connectionManagement.ts @@ -171,6 +171,7 @@ export interface IConnectionManagementService { isConnected(fileUri: string): boolean; + refreshAzureAccountTokenIfNecessary(uri: string): Promise; /** * Returns true if the connection profile is connected */ diff --git a/src/sql/platform/connection/common/connectionProfile.ts b/src/sql/platform/connection/common/connectionProfile.ts index d308911fcc..13751a91b8 100644 --- a/src/sql/platform/connection/common/connectionProfile.ts +++ b/src/sql/platform/connection/common/connectionProfile.ts @@ -66,6 +66,10 @@ export class ConnectionProfile extends ProviderConnectionInfo implements interfa if (model.options.registeredServerDescription) { this.registeredServerDescription = model.options.registeredServerDescription; } + const expiry = model.options.expiresOn; + if (typeof expiry === 'number' && !Number.isNaN(expiry)) { + this.options.expiresOn = model.options.expiresOn; + } } } else { //Default for a new connection diff --git a/src/sql/platform/connection/test/common/testConnectionManagementService.ts b/src/sql/platform/connection/test/common/testConnectionManagementService.ts index b317d06d1d..84c56702fd 100644 --- a/src/sql/platform/connection/test/common/testConnectionManagementService.ts +++ b/src/sql/platform/connection/test/common/testConnectionManagementService.ts @@ -309,4 +309,8 @@ export class TestConnectionManagementService implements IConnectionManagementSer getConnection(uri: string): ConnectionProfile { return undefined!; } + + refreshAzureAccountTokenIfNecessary(uri: string): Promise { + return undefined; + } } diff --git a/src/sql/workbench/contrib/query/browser/queryActions.ts b/src/sql/workbench/contrib/query/browser/queryActions.ts index 422381d6ca..b3eae39870 100644 --- a/src/sql/workbench/contrib/query/browser/queryActions.ts +++ b/src/sql/workbench/contrib/query/browser/queryActions.ts @@ -206,6 +206,7 @@ export class RunQueryAction extends QueryTaskbarAction { public override async run(): Promise { if (!this.editor.isSelectionEmpty()) { + await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri); if (this.isConnected(this.editor)) { // If we are already connected, run the query this.runQuery(this.editor); @@ -220,6 +221,7 @@ export class RunQueryAction extends QueryTaskbarAction { public async runCurrent(): Promise { if (!this.editor.isSelectionEmpty()) { + await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri); if (this.isConnected(this.editor)) { // If we are already connected, run the query this.runQuery(this.editor, true); @@ -307,6 +309,7 @@ export class EstimatedQueryPlanAction extends QueryTaskbarAction { public override async run(): Promise { if (!this.editor.isSelectionEmpty()) { + await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri); if (this.isConnected(this.editor)) { // If we are already connected, run the query this.runQuery(this.editor); @@ -346,6 +349,7 @@ export class ActualQueryPlanAction extends QueryTaskbarAction { public override async run(): Promise { if (!this.editor.isSelectionEmpty()) { + await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri); if (this.isConnected(this.editor)) { // If we are already connected, run the query this.runQuery(this.editor); diff --git a/src/sql/workbench/services/connection/browser/connectionManagementService.ts b/src/sql/workbench/services/connection/browser/connectionManagementService.ts index 6a14fa9ab1..52e9530f57 100644 --- a/src/sql/workbench/services/connection/browser/connectionManagementService.ts +++ b/src/sql/workbench/services/connection/browser/connectionManagementService.ts @@ -66,6 +66,7 @@ export class ConnectionManagementService extends Disposable implements IConnecti private _onConnectionChanged = new Emitter(); private _onLanguageFlavorChanged = new Emitter(); private _connectionGlobalStatus = new ConnectionGlobalStatus(this._notificationService); + private _uriToReconnectPromiseMap: { [uri: string]: Promise } = {}; private _mementoContext: Memento; private _mementoObj: MementoObject; @@ -863,6 +864,7 @@ export class ConnectionManagementService extends Disposable implements IConnecti this._logService.info(`No security tokens found for account`); } connection.options['azureAccountToken'] = token.token; + connection.options['expiresOn'] = token.expiresOn; connection.options['password'] = ''; return true; } else { @@ -874,6 +876,62 @@ export class ConnectionManagementService extends Disposable implements IConnecti return false; } + /** + * Refresh Azure access token if it's expired. + * @param uri connection uri + * @returns true if no need to refresh or successfully refreshed token + */ + public async refreshAzureAccountTokenIfNecessary(uri: string): Promise { + const profile = this._connectionStatusManager.getConnectionProfile(uri); + if (!profile) { + this._logService.warn(`Connection not found for uri ${uri}`); + return false; + } + + //wait for the pending reconnction promise if any + const previousReconnectPromise = this._uriToReconnectPromiseMap[uri]; + if (previousReconnectPromise) { + this._logService.info(`Found pending reconnect promise for uri ${uri}, waiting.`); + try { + const previousConnectionResult = await previousReconnectPromise; + if (previousConnectionResult && previousConnectionResult.connected) { + this._logService.info(`Previous pending reconnection for uri ${uri} succeeded.`); + return true; + } + this._logService.info(`Previous pending reconnection for uri ${uri} failed.`); + } catch (err) { + this._logService.info(`Previous pending reconnect promise for uri ${uri} is rejected with error ${err}, will attempt to reconnect if necessary.`); + } + } + + const expiry = profile.options.expiresOn; + if (typeof expiry === 'number' && !Number.isNaN(expiry)) { + const currentTime = new Date().getTime() / 1000; + const maxTolerance = 2 * 60; // two minutes + if (expiry - currentTime < maxTolerance) { + this._logService.info(`Access token expired for connection ${profile.id} with uri ${uri}`); + try { + const connectionResultPromise = this.connect(profile, uri); + this._uriToReconnectPromiseMap[uri] = connectionResultPromise; + const connectionResult = await connectionResultPromise; + if (!connectionResult) { + this._logService.error(`Failed to refresh connection ${profile.id} with uri ${uri}, invalid connection result.`); + throw new Error(nls.localize('connection.invalidConnectionResult', "Connection result is invalid")); + } else if (!connectionResult.connected) { + this._logService.error(`Failed to refresh connection ${profile.id} with uri ${uri}, error code: ${connectionResult.errorCode}, error message: ${connectionResult.errorMessage}`); + throw new Error(nls.localize('connection.refreshAzureTokenFailure', "Failed to refresh Azure account token for connection")); + } + this._logService.info(`Successfully refreshed token for connection ${profile.id} with uri ${uri}, result: ${connectionResult.connected} ${connectionResult.connectionProfile}, isConnected: ${this.isConnected(uri)}, ${this._connectionStatusManager.getConnectionProfile(uri)}`); + return true; + } finally { + delete this._uriToReconnectPromiseMap[uri]; + } + } + this._logService.info(`No need to refresh Azure acccount token for connection ${profile.id} with uri ${uri}`); + } + return true; + } + // Request Senders private async sendConnectRequest(connection: interfaces.IConnectionProfile, uri: string): Promise { let connectionInfo = Object.assign({}, { @@ -1240,8 +1298,9 @@ export class ConnectionManagementService extends Disposable implements IConnecti return this._connectionStatusManager.isConnected(fileUri) ? this._connectionStatusManager.findConnection(fileUri) : undefined; } - public listDatabases(connectionUri: string): Thenable { + public async listDatabases(connectionUri: string): Promise { const self = this; + await this.refreshAzureAccountTokenIfNecessary(connectionUri); if (self.isConnected(connectionUri)) { return self.sendListDatabasesRequest(connectionUri); } diff --git a/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts b/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts index ce1d705568..ac9f0d35dc 100644 --- a/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts +++ b/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts @@ -120,7 +120,12 @@ suite('SQL ConnectionManagementService tests', () => { connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is( c => c.serverName === connectionProfileWithEmptyUnsavedPassword.serverName))).returns( () => Promise.resolve({ profile: connectionProfileWithEmptyUnsavedPassword, savedCred: false })); - connectionStore.setup(x => x.isPasswordRequired(TypeMoq.It.isAny())).returns(() => true); + connectionStore.setup(x => x.isPasswordRequired(TypeMoq.It.isAny())).returns((profile) => { + if (profile.authenticationType === Constants.azureMFA) { + return false; + } + return true; + }); connectionStore.setup(x => x.getConnectionProfileGroups(false, undefined)).returns(() => [root]); connectionStore.setup(x => x.savePassword(TypeMoq.It.isAny())).returns(() => Promise.resolve(true)); @@ -1693,6 +1698,78 @@ suite('SQL ConnectionManagementService tests', () => { assert.strictEqual(profileWithCredentials.options['azureAccountToken'], testToken); }); + test('refreshAzureAccountTokenIfNecessary refreshes Azure access token if existing token is expired', async () => { + const uri: string = 'Editor Uri'; + // Set up a connection profile that uses Azure + const azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile); + azureConnectionProfile.authenticationType = 'AzureMFA'; + const username = 'testuser@microsoft.com'; + azureConnectionProfile.azureAccount = username; + const servername = 'test-database.database.windows.net'; + azureConnectionProfile.serverName = servername; + const providerId = 'azure_PublicCloud'; + azureConnectionProfile.azureTenantId = 'testTenant'; + + const expiredToken = { + token: 'expiredToken', + tokenType: 'Bearer', + expiresOn: 0, + }; + + const freshToken = { + token: 'freshToken', + tokenType: 'Bearer', + expiresOn: new Date().getTime() / 1000 + 7200, + }; + + // every connectionStatusManager.connect will call accountManagementService.getAccountSecurityToken twice + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken)); + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken)); + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(freshToken)); + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(freshToken)); + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken)); + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken)); + + accountManagementService.setup(x => x.getAccounts()).returns(() => { + return Promise.resolve([ + { + key: { + accountId: username, + providerId: providerId + }, + displayInfo: undefined, + isStale: false, + properties: undefined + } + ]); + }); + + connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({ + profile: profile, + savedCred: false + })); + + (connectionManagementService as any)._connectionStatusManager = connectionStatusManager; + await connect(uri, undefined, false, azureConnectionProfile); + + const oldProfile = connectionStatusManager.getConnectionProfile(uri); + assert.strictEqual(oldProfile.options['expiresOn'], expiredToken.expiresOn); + + const refreshRes1 = await connectionManagementService.refreshAzureAccountTokenIfNecessary(uri); + assert.strictEqual(refreshRes1, true); + + // first refresh should give us the new token + const newProfile1 = connectionStatusManager.getConnectionProfile(uri); + assert.strictEqual(newProfile1.options['expiresOn'], freshToken.expiresOn); + + const refreshRes2 = await connectionManagementService.refreshAzureAccountTokenIfNecessary(uri); + assert.strictEqual(refreshRes2, true); + + // second refresh should be a no-op + const newProfile2 = connectionStatusManager.getConnectionProfile(uri); + assert.strictEqual(newProfile2.options['expiresOn'], freshToken.expiresOn); + }); + test('addSavedPassword fills in Azure access token for selected tenant', async () => { // Set up a connection profile that uses Azure let azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile);