Add token expiration handling for AzureMFA auth (#16936)

* refresh azure account token if it's expired before sending query/list requests

* fix several connection checks && add more logging

* fix async UI glitch during token refreshing

* cleanup

* minor fix

* add test for refreshAzureAccountTokenIfNecessary

* address comments

* comments

* comments

* comments

* error handling
This commit is contained in:
Hai Cao
2021-09-02 17:24:45 -07:00
committed by GitHub
parent 73c20345e9
commit 0bc2a50d78
9 changed files with 171 additions and 4 deletions

View File

@@ -206,6 +206,7 @@ export class RunQueryAction extends QueryTaskbarAction {
public override async run(): Promise<void> {
if (!this.editor.isSelectionEmpty()) {
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
if (this.isConnected(this.editor)) {
// If we are already connected, run the query
this.runQuery(this.editor);
@@ -220,6 +221,7 @@ export class RunQueryAction extends QueryTaskbarAction {
public async runCurrent(): Promise<void> {
if (!this.editor.isSelectionEmpty()) {
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
if (this.isConnected(this.editor)) {
// If we are already connected, run the query
this.runQuery(this.editor, true);
@@ -307,6 +309,7 @@ export class EstimatedQueryPlanAction extends QueryTaskbarAction {
public override async run(): Promise<void> {
if (!this.editor.isSelectionEmpty()) {
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
if (this.isConnected(this.editor)) {
// If we are already connected, run the query
this.runQuery(this.editor);
@@ -346,6 +349,7 @@ export class ActualQueryPlanAction extends QueryTaskbarAction {
public override async run(): Promise<void> {
if (!this.editor.isSelectionEmpty()) {
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
if (this.isConnected(this.editor)) {
// If we are already connected, run the query
this.runQuery(this.editor);

View File

@@ -66,6 +66,7 @@ export class ConnectionManagementService extends Disposable implements IConnecti
private _onConnectionChanged = new Emitter<IConnectionParams>();
private _onLanguageFlavorChanged = new Emitter<azdata.DidChangeLanguageFlavorParams>();
private _connectionGlobalStatus = new ConnectionGlobalStatus(this._notificationService);
private _uriToReconnectPromiseMap: { [uri: string]: Promise<IConnectionResult> } = {};
private _mementoContext: Memento;
private _mementoObj: MementoObject;
@@ -863,6 +864,7 @@ export class ConnectionManagementService extends Disposable implements IConnecti
this._logService.info(`No security tokens found for account`);
}
connection.options['azureAccountToken'] = token.token;
connection.options['expiresOn'] = token.expiresOn;
connection.options['password'] = '';
return true;
} else {
@@ -874,6 +876,62 @@ export class ConnectionManagementService extends Disposable implements IConnecti
return false;
}
/**
* Refresh Azure access token if it's expired.
* @param uri connection uri
* @returns true if no need to refresh or successfully refreshed token
*/
public async refreshAzureAccountTokenIfNecessary(uri: string): Promise<boolean> {
const profile = this._connectionStatusManager.getConnectionProfile(uri);
if (!profile) {
this._logService.warn(`Connection not found for uri ${uri}`);
return false;
}
//wait for the pending reconnction promise if any
const previousReconnectPromise = this._uriToReconnectPromiseMap[uri];
if (previousReconnectPromise) {
this._logService.info(`Found pending reconnect promise for uri ${uri}, waiting.`);
try {
const previousConnectionResult = await previousReconnectPromise;
if (previousConnectionResult && previousConnectionResult.connected) {
this._logService.info(`Previous pending reconnection for uri ${uri} succeeded.`);
return true;
}
this._logService.info(`Previous pending reconnection for uri ${uri} failed.`);
} catch (err) {
this._logService.info(`Previous pending reconnect promise for uri ${uri} is rejected with error ${err}, will attempt to reconnect if necessary.`);
}
}
const expiry = profile.options.expiresOn;
if (typeof expiry === 'number' && !Number.isNaN(expiry)) {
const currentTime = new Date().getTime() / 1000;
const maxTolerance = 2 * 60; // two minutes
if (expiry - currentTime < maxTolerance) {
this._logService.info(`Access token expired for connection ${profile.id} with uri ${uri}`);
try {
const connectionResultPromise = this.connect(profile, uri);
this._uriToReconnectPromiseMap[uri] = connectionResultPromise;
const connectionResult = await connectionResultPromise;
if (!connectionResult) {
this._logService.error(`Failed to refresh connection ${profile.id} with uri ${uri}, invalid connection result.`);
throw new Error(nls.localize('connection.invalidConnectionResult', "Connection result is invalid"));
} else if (!connectionResult.connected) {
this._logService.error(`Failed to refresh connection ${profile.id} with uri ${uri}, error code: ${connectionResult.errorCode}, error message: ${connectionResult.errorMessage}`);
throw new Error(nls.localize('connection.refreshAzureTokenFailure', "Failed to refresh Azure account token for connection"));
}
this._logService.info(`Successfully refreshed token for connection ${profile.id} with uri ${uri}, result: ${connectionResult.connected} ${connectionResult.connectionProfile}, isConnected: ${this.isConnected(uri)}, ${this._connectionStatusManager.getConnectionProfile(uri)}`);
return true;
} finally {
delete this._uriToReconnectPromiseMap[uri];
}
}
this._logService.info(`No need to refresh Azure acccount token for connection ${profile.id} with uri ${uri}`);
}
return true;
}
// Request Senders
private async sendConnectRequest(connection: interfaces.IConnectionProfile, uri: string): Promise<boolean> {
let connectionInfo = Object.assign({}, {
@@ -1240,8 +1298,9 @@ export class ConnectionManagementService extends Disposable implements IConnecti
return this._connectionStatusManager.isConnected(fileUri) ? this._connectionStatusManager.findConnection(fileUri) : undefined;
}
public listDatabases(connectionUri: string): Thenable<azdata.ListDatabasesResult | undefined> {
public async listDatabases(connectionUri: string): Promise<azdata.ListDatabasesResult | undefined> {
const self = this;
await this.refreshAzureAccountTokenIfNecessary(connectionUri);
if (self.isConnected(connectionUri)) {
return self.sendListDatabasesRequest(connectionUri);
}

View File

@@ -120,7 +120,12 @@ suite('SQL ConnectionManagementService tests', () => {
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is<IConnectionProfile>(
c => c.serverName === connectionProfileWithEmptyUnsavedPassword.serverName))).returns(
() => Promise.resolve({ profile: connectionProfileWithEmptyUnsavedPassword, savedCred: false }));
connectionStore.setup(x => x.isPasswordRequired(TypeMoq.It.isAny())).returns(() => true);
connectionStore.setup(x => x.isPasswordRequired(TypeMoq.It.isAny())).returns((profile) => {
if (profile.authenticationType === Constants.azureMFA) {
return false;
}
return true;
});
connectionStore.setup(x => x.getConnectionProfileGroups(false, undefined)).returns(() => [root]);
connectionStore.setup(x => x.savePassword(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
@@ -1693,6 +1698,78 @@ suite('SQL ConnectionManagementService tests', () => {
assert.strictEqual(profileWithCredentials.options['azureAccountToken'], testToken);
});
test('refreshAzureAccountTokenIfNecessary refreshes Azure access token if existing token is expired', async () => {
const uri: string = 'Editor Uri';
// Set up a connection profile that uses Azure
const azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile);
azureConnectionProfile.authenticationType = 'AzureMFA';
const username = 'testuser@microsoft.com';
azureConnectionProfile.azureAccount = username;
const servername = 'test-database.database.windows.net';
azureConnectionProfile.serverName = servername;
const providerId = 'azure_PublicCloud';
azureConnectionProfile.azureTenantId = 'testTenant';
const expiredToken = {
token: 'expiredToken',
tokenType: 'Bearer',
expiresOn: 0,
};
const freshToken = {
token: 'freshToken',
tokenType: 'Bearer',
expiresOn: new Date().getTime() / 1000 + 7200,
};
// every connectionStatusManager.connect will call accountManagementService.getAccountSecurityToken twice
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(freshToken));
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(freshToken));
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
accountManagementService.setup(x => x.getAccounts()).returns(() => {
return Promise.resolve<azdata.Account[]>([
{
key: {
accountId: username,
providerId: providerId
},
displayInfo: undefined,
isStale: false,
properties: undefined
}
]);
});
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({
profile: profile,
savedCred: false
}));
(connectionManagementService as any)._connectionStatusManager = connectionStatusManager;
await connect(uri, undefined, false, azureConnectionProfile);
const oldProfile = connectionStatusManager.getConnectionProfile(uri);
assert.strictEqual(oldProfile.options['expiresOn'], expiredToken.expiresOn);
const refreshRes1 = await connectionManagementService.refreshAzureAccountTokenIfNecessary(uri);
assert.strictEqual(refreshRes1, true);
// first refresh should give us the new token
const newProfile1 = connectionStatusManager.getConnectionProfile(uri);
assert.strictEqual(newProfile1.options['expiresOn'], freshToken.expiresOn);
const refreshRes2 = await connectionManagementService.refreshAzureAccountTokenIfNecessary(uri);
assert.strictEqual(refreshRes2, true);
// second refresh should be a no-op
const newProfile2 = connectionStatusManager.getConnectionProfile(uri);
assert.strictEqual(newProfile2.options['expiresOn'], freshToken.expiresOn);
});
test('addSavedPassword fills in Azure access token for selected tenant', async () => {
// Set up a connection profile that uses Azure
let azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile);