mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
Add token expiration handling for AzureMFA auth (#16936)
* refresh azure account token if it's expired before sending query/list requests * fix several connection checks && add more logging * fix async UI glitch during token refreshing * cleanup * minor fix * add test for refreshAzureAccountTokenIfNecessary * address comments * comments * comments * comments * error handling
This commit is contained in:
@@ -184,17 +184,20 @@ export abstract class AzureAuth implements vscode.Disposable {
|
|||||||
const currentTime = new Date().getTime() / 1000;
|
const currentTime = new Date().getTime() / 1000;
|
||||||
|
|
||||||
let accessToken = cachedTokens.accessToken;
|
let accessToken = cachedTokens.accessToken;
|
||||||
|
let expiresOn = Number(cachedTokens.expiresOn);
|
||||||
const remainingTime = expiry - currentTime;
|
const remainingTime = expiry - currentTime;
|
||||||
const maxTolerance = 2 * 60; // two minutes
|
const maxTolerance = 2 * 60; // two minutes
|
||||||
|
|
||||||
if (remainingTime < maxTolerance) {
|
if (remainingTime < maxTolerance) {
|
||||||
const result = await this.refreshToken(tenant, resource, cachedTokens.refreshToken);
|
const result = await this.refreshToken(tenant, resource, cachedTokens.refreshToken);
|
||||||
accessToken = result.accessToken;
|
accessToken = result.accessToken;
|
||||||
|
expiresOn = Number(result.expiresOn);
|
||||||
}
|
}
|
||||||
// Let's just return here.
|
// Let's just return here.
|
||||||
if (accessToken) {
|
if (accessToken) {
|
||||||
return {
|
return {
|
||||||
...accessToken,
|
...accessToken,
|
||||||
|
expiresOn: expiresOn,
|
||||||
tokenType: 'Bearer'
|
tokenType: 'Bearer'
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -214,6 +217,7 @@ export abstract class AzureAuth implements vscode.Disposable {
|
|||||||
if (result.accessToken) {
|
if (result.accessToken) {
|
||||||
return {
|
return {
|
||||||
...result.accessToken,
|
...result.accessToken,
|
||||||
|
expiresOn: Number(result.expiresOn),
|
||||||
tokenType: 'Bearer'
|
tokenType: 'Bearer'
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -674,6 +678,11 @@ export interface Token extends AccountKey {
|
|||||||
*/
|
*/
|
||||||
token: string;
|
token: string;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Access token expiry timestamp
|
||||||
|
*/
|
||||||
|
expiresOn?: number;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TokenType
|
* TokenType
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -96,8 +96,8 @@ describe('Azure Authentication', function () {
|
|||||||
it('token recieved for ossRdbmns resource', async function () {
|
it('token recieved for ossRdbmns resource', async function () {
|
||||||
azureAuthCodeGrant.setup(x => x.getTenants(mockToken)).returns(() => {
|
azureAuthCodeGrant.setup(x => x.getTenants(mockToken)).returns(() => {
|
||||||
return Promise.resolve([
|
return Promise.resolve([
|
||||||
mockTenant
|
mockTenant
|
||||||
]);
|
]);
|
||||||
});
|
});
|
||||||
azureAuthCodeGrant.setup(x => x.getTokenHelper(mockTenant, provider.settings.ossRdbmsResource, TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
azureAuthCodeGrant.setup(x => x.getTokenHelper(mockTenant, provider.settings.ossRdbmsResource, TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||||
return Promise.resolve({
|
return Promise.resolve({
|
||||||
|
|||||||
9
src/sql/azdata.proposed.d.ts
vendored
9
src/sql/azdata.proposed.d.ts
vendored
@@ -947,4 +947,13 @@ declare module 'azdata' {
|
|||||||
*/
|
*/
|
||||||
parentTypeName?: string;
|
parentTypeName?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export namespace accounts {
|
||||||
|
export interface AccountSecurityToken {
|
||||||
|
/**
|
||||||
|
* Access token expiry timestamp
|
||||||
|
*/
|
||||||
|
expiresOn?: number
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -171,6 +171,7 @@ export interface IConnectionManagementService {
|
|||||||
|
|
||||||
isConnected(fileUri: string): boolean;
|
isConnected(fileUri: string): boolean;
|
||||||
|
|
||||||
|
refreshAzureAccountTokenIfNecessary(uri: string): Promise<boolean>;
|
||||||
/**
|
/**
|
||||||
* Returns true if the connection profile is connected
|
* Returns true if the connection profile is connected
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ export class ConnectionProfile extends ProviderConnectionInfo implements interfa
|
|||||||
if (model.options.registeredServerDescription) {
|
if (model.options.registeredServerDescription) {
|
||||||
this.registeredServerDescription = model.options.registeredServerDescription;
|
this.registeredServerDescription = model.options.registeredServerDescription;
|
||||||
}
|
}
|
||||||
|
const expiry = model.options.expiresOn;
|
||||||
|
if (typeof expiry === 'number' && !Number.isNaN(expiry)) {
|
||||||
|
this.options.expiresOn = model.options.expiresOn;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
//Default for a new connection
|
//Default for a new connection
|
||||||
|
|||||||
@@ -309,4 +309,8 @@ export class TestConnectionManagementService implements IConnectionManagementSer
|
|||||||
getConnection(uri: string): ConnectionProfile {
|
getConnection(uri: string): ConnectionProfile {
|
||||||
return undefined!;
|
return undefined!;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refreshAzureAccountTokenIfNecessary(uri: string): Promise<boolean> {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -206,6 +206,7 @@ export class RunQueryAction extends QueryTaskbarAction {
|
|||||||
|
|
||||||
public override async run(): Promise<void> {
|
public override async run(): Promise<void> {
|
||||||
if (!this.editor.isSelectionEmpty()) {
|
if (!this.editor.isSelectionEmpty()) {
|
||||||
|
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
|
||||||
if (this.isConnected(this.editor)) {
|
if (this.isConnected(this.editor)) {
|
||||||
// If we are already connected, run the query
|
// If we are already connected, run the query
|
||||||
this.runQuery(this.editor);
|
this.runQuery(this.editor);
|
||||||
@@ -220,6 +221,7 @@ export class RunQueryAction extends QueryTaskbarAction {
|
|||||||
|
|
||||||
public async runCurrent(): Promise<void> {
|
public async runCurrent(): Promise<void> {
|
||||||
if (!this.editor.isSelectionEmpty()) {
|
if (!this.editor.isSelectionEmpty()) {
|
||||||
|
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
|
||||||
if (this.isConnected(this.editor)) {
|
if (this.isConnected(this.editor)) {
|
||||||
// If we are already connected, run the query
|
// If we are already connected, run the query
|
||||||
this.runQuery(this.editor, true);
|
this.runQuery(this.editor, true);
|
||||||
@@ -307,6 +309,7 @@ export class EstimatedQueryPlanAction extends QueryTaskbarAction {
|
|||||||
|
|
||||||
public override async run(): Promise<void> {
|
public override async run(): Promise<void> {
|
||||||
if (!this.editor.isSelectionEmpty()) {
|
if (!this.editor.isSelectionEmpty()) {
|
||||||
|
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
|
||||||
if (this.isConnected(this.editor)) {
|
if (this.isConnected(this.editor)) {
|
||||||
// If we are already connected, run the query
|
// If we are already connected, run the query
|
||||||
this.runQuery(this.editor);
|
this.runQuery(this.editor);
|
||||||
@@ -346,6 +349,7 @@ export class ActualQueryPlanAction extends QueryTaskbarAction {
|
|||||||
|
|
||||||
public override async run(): Promise<void> {
|
public override async run(): Promise<void> {
|
||||||
if (!this.editor.isSelectionEmpty()) {
|
if (!this.editor.isSelectionEmpty()) {
|
||||||
|
await this.connectionManagementService.refreshAzureAccountTokenIfNecessary(this.editor.input.uri);
|
||||||
if (this.isConnected(this.editor)) {
|
if (this.isConnected(this.editor)) {
|
||||||
// If we are already connected, run the query
|
// If we are already connected, run the query
|
||||||
this.runQuery(this.editor);
|
this.runQuery(this.editor);
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ export class ConnectionManagementService extends Disposable implements IConnecti
|
|||||||
private _onConnectionChanged = new Emitter<IConnectionParams>();
|
private _onConnectionChanged = new Emitter<IConnectionParams>();
|
||||||
private _onLanguageFlavorChanged = new Emitter<azdata.DidChangeLanguageFlavorParams>();
|
private _onLanguageFlavorChanged = new Emitter<azdata.DidChangeLanguageFlavorParams>();
|
||||||
private _connectionGlobalStatus = new ConnectionGlobalStatus(this._notificationService);
|
private _connectionGlobalStatus = new ConnectionGlobalStatus(this._notificationService);
|
||||||
|
private _uriToReconnectPromiseMap: { [uri: string]: Promise<IConnectionResult> } = {};
|
||||||
|
|
||||||
private _mementoContext: Memento;
|
private _mementoContext: Memento;
|
||||||
private _mementoObj: MementoObject;
|
private _mementoObj: MementoObject;
|
||||||
@@ -863,6 +864,7 @@ export class ConnectionManagementService extends Disposable implements IConnecti
|
|||||||
this._logService.info(`No security tokens found for account`);
|
this._logService.info(`No security tokens found for account`);
|
||||||
}
|
}
|
||||||
connection.options['azureAccountToken'] = token.token;
|
connection.options['azureAccountToken'] = token.token;
|
||||||
|
connection.options['expiresOn'] = token.expiresOn;
|
||||||
connection.options['password'] = '';
|
connection.options['password'] = '';
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
@@ -874,6 +876,62 @@ export class ConnectionManagementService extends Disposable implements IConnecti
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refresh Azure access token if it's expired.
|
||||||
|
* @param uri connection uri
|
||||||
|
* @returns true if no need to refresh or successfully refreshed token
|
||||||
|
*/
|
||||||
|
public async refreshAzureAccountTokenIfNecessary(uri: string): Promise<boolean> {
|
||||||
|
const profile = this._connectionStatusManager.getConnectionProfile(uri);
|
||||||
|
if (!profile) {
|
||||||
|
this._logService.warn(`Connection not found for uri ${uri}`);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//wait for the pending reconnction promise if any
|
||||||
|
const previousReconnectPromise = this._uriToReconnectPromiseMap[uri];
|
||||||
|
if (previousReconnectPromise) {
|
||||||
|
this._logService.info(`Found pending reconnect promise for uri ${uri}, waiting.`);
|
||||||
|
try {
|
||||||
|
const previousConnectionResult = await previousReconnectPromise;
|
||||||
|
if (previousConnectionResult && previousConnectionResult.connected) {
|
||||||
|
this._logService.info(`Previous pending reconnection for uri ${uri} succeeded.`);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
this._logService.info(`Previous pending reconnection for uri ${uri} failed.`);
|
||||||
|
} catch (err) {
|
||||||
|
this._logService.info(`Previous pending reconnect promise for uri ${uri} is rejected with error ${err}, will attempt to reconnect if necessary.`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const expiry = profile.options.expiresOn;
|
||||||
|
if (typeof expiry === 'number' && !Number.isNaN(expiry)) {
|
||||||
|
const currentTime = new Date().getTime() / 1000;
|
||||||
|
const maxTolerance = 2 * 60; // two minutes
|
||||||
|
if (expiry - currentTime < maxTolerance) {
|
||||||
|
this._logService.info(`Access token expired for connection ${profile.id} with uri ${uri}`);
|
||||||
|
try {
|
||||||
|
const connectionResultPromise = this.connect(profile, uri);
|
||||||
|
this._uriToReconnectPromiseMap[uri] = connectionResultPromise;
|
||||||
|
const connectionResult = await connectionResultPromise;
|
||||||
|
if (!connectionResult) {
|
||||||
|
this._logService.error(`Failed to refresh connection ${profile.id} with uri ${uri}, invalid connection result.`);
|
||||||
|
throw new Error(nls.localize('connection.invalidConnectionResult', "Connection result is invalid"));
|
||||||
|
} else if (!connectionResult.connected) {
|
||||||
|
this._logService.error(`Failed to refresh connection ${profile.id} with uri ${uri}, error code: ${connectionResult.errorCode}, error message: ${connectionResult.errorMessage}`);
|
||||||
|
throw new Error(nls.localize('connection.refreshAzureTokenFailure', "Failed to refresh Azure account token for connection"));
|
||||||
|
}
|
||||||
|
this._logService.info(`Successfully refreshed token for connection ${profile.id} with uri ${uri}, result: ${connectionResult.connected} ${connectionResult.connectionProfile}, isConnected: ${this.isConnected(uri)}, ${this._connectionStatusManager.getConnectionProfile(uri)}`);
|
||||||
|
return true;
|
||||||
|
} finally {
|
||||||
|
delete this._uriToReconnectPromiseMap[uri];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this._logService.info(`No need to refresh Azure acccount token for connection ${profile.id} with uri ${uri}`);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Request Senders
|
// Request Senders
|
||||||
private async sendConnectRequest(connection: interfaces.IConnectionProfile, uri: string): Promise<boolean> {
|
private async sendConnectRequest(connection: interfaces.IConnectionProfile, uri: string): Promise<boolean> {
|
||||||
let connectionInfo = Object.assign({}, {
|
let connectionInfo = Object.assign({}, {
|
||||||
@@ -1240,8 +1298,9 @@ export class ConnectionManagementService extends Disposable implements IConnecti
|
|||||||
return this._connectionStatusManager.isConnected(fileUri) ? this._connectionStatusManager.findConnection(fileUri) : undefined;
|
return this._connectionStatusManager.isConnected(fileUri) ? this._connectionStatusManager.findConnection(fileUri) : undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
public listDatabases(connectionUri: string): Thenable<azdata.ListDatabasesResult | undefined> {
|
public async listDatabases(connectionUri: string): Promise<azdata.ListDatabasesResult | undefined> {
|
||||||
const self = this;
|
const self = this;
|
||||||
|
await this.refreshAzureAccountTokenIfNecessary(connectionUri);
|
||||||
if (self.isConnected(connectionUri)) {
|
if (self.isConnected(connectionUri)) {
|
||||||
return self.sendListDatabasesRequest(connectionUri);
|
return self.sendListDatabasesRequest(connectionUri);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -120,7 +120,12 @@ suite('SQL ConnectionManagementService tests', () => {
|
|||||||
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is<IConnectionProfile>(
|
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is<IConnectionProfile>(
|
||||||
c => c.serverName === connectionProfileWithEmptyUnsavedPassword.serverName))).returns(
|
c => c.serverName === connectionProfileWithEmptyUnsavedPassword.serverName))).returns(
|
||||||
() => Promise.resolve({ profile: connectionProfileWithEmptyUnsavedPassword, savedCred: false }));
|
() => Promise.resolve({ profile: connectionProfileWithEmptyUnsavedPassword, savedCred: false }));
|
||||||
connectionStore.setup(x => x.isPasswordRequired(TypeMoq.It.isAny())).returns(() => true);
|
connectionStore.setup(x => x.isPasswordRequired(TypeMoq.It.isAny())).returns((profile) => {
|
||||||
|
if (profile.authenticationType === Constants.azureMFA) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
connectionStore.setup(x => x.getConnectionProfileGroups(false, undefined)).returns(() => [root]);
|
connectionStore.setup(x => x.getConnectionProfileGroups(false, undefined)).returns(() => [root]);
|
||||||
connectionStore.setup(x => x.savePassword(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
connectionStore.setup(x => x.savePassword(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||||
|
|
||||||
@@ -1693,6 +1698,78 @@ suite('SQL ConnectionManagementService tests', () => {
|
|||||||
assert.strictEqual(profileWithCredentials.options['azureAccountToken'], testToken);
|
assert.strictEqual(profileWithCredentials.options['azureAccountToken'], testToken);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test('refreshAzureAccountTokenIfNecessary refreshes Azure access token if existing token is expired', async () => {
|
||||||
|
const uri: string = 'Editor Uri';
|
||||||
|
// Set up a connection profile that uses Azure
|
||||||
|
const azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile);
|
||||||
|
azureConnectionProfile.authenticationType = 'AzureMFA';
|
||||||
|
const username = 'testuser@microsoft.com';
|
||||||
|
azureConnectionProfile.azureAccount = username;
|
||||||
|
const servername = 'test-database.database.windows.net';
|
||||||
|
azureConnectionProfile.serverName = servername;
|
||||||
|
const providerId = 'azure_PublicCloud';
|
||||||
|
azureConnectionProfile.azureTenantId = 'testTenant';
|
||||||
|
|
||||||
|
const expiredToken = {
|
||||||
|
token: 'expiredToken',
|
||||||
|
tokenType: 'Bearer',
|
||||||
|
expiresOn: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
const freshToken = {
|
||||||
|
token: 'freshToken',
|
||||||
|
tokenType: 'Bearer',
|
||||||
|
expiresOn: new Date().getTime() / 1000 + 7200,
|
||||||
|
};
|
||||||
|
|
||||||
|
// every connectionStatusManager.connect will call accountManagementService.getAccountSecurityToken twice
|
||||||
|
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
|
||||||
|
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
|
||||||
|
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(freshToken));
|
||||||
|
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(freshToken));
|
||||||
|
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
|
||||||
|
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(expiredToken));
|
||||||
|
|
||||||
|
accountManagementService.setup(x => x.getAccounts()).returns(() => {
|
||||||
|
return Promise.resolve<azdata.Account[]>([
|
||||||
|
{
|
||||||
|
key: {
|
||||||
|
accountId: username,
|
||||||
|
providerId: providerId
|
||||||
|
},
|
||||||
|
displayInfo: undefined,
|
||||||
|
isStale: false,
|
||||||
|
properties: undefined
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
});
|
||||||
|
|
||||||
|
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({
|
||||||
|
profile: profile,
|
||||||
|
savedCred: false
|
||||||
|
}));
|
||||||
|
|
||||||
|
(connectionManagementService as any)._connectionStatusManager = connectionStatusManager;
|
||||||
|
await connect(uri, undefined, false, azureConnectionProfile);
|
||||||
|
|
||||||
|
const oldProfile = connectionStatusManager.getConnectionProfile(uri);
|
||||||
|
assert.strictEqual(oldProfile.options['expiresOn'], expiredToken.expiresOn);
|
||||||
|
|
||||||
|
const refreshRes1 = await connectionManagementService.refreshAzureAccountTokenIfNecessary(uri);
|
||||||
|
assert.strictEqual(refreshRes1, true);
|
||||||
|
|
||||||
|
// first refresh should give us the new token
|
||||||
|
const newProfile1 = connectionStatusManager.getConnectionProfile(uri);
|
||||||
|
assert.strictEqual(newProfile1.options['expiresOn'], freshToken.expiresOn);
|
||||||
|
|
||||||
|
const refreshRes2 = await connectionManagementService.refreshAzureAccountTokenIfNecessary(uri);
|
||||||
|
assert.strictEqual(refreshRes2, true);
|
||||||
|
|
||||||
|
// second refresh should be a no-op
|
||||||
|
const newProfile2 = connectionStatusManager.getConnectionProfile(uri);
|
||||||
|
assert.strictEqual(newProfile2.options['expiresOn'], freshToken.expiresOn);
|
||||||
|
});
|
||||||
|
|
||||||
test('addSavedPassword fills in Azure access token for selected tenant', async () => {
|
test('addSavedPassword fills in Azure access token for selected tenant', async () => {
|
||||||
// Set up a connection profile that uses Azure
|
// Set up a connection profile that uses Azure
|
||||||
let azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile);
|
let azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile);
|
||||||
|
|||||||
Reference in New Issue
Block a user