diff --git a/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts b/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts index 46005ce7c4..19a8765747 100644 --- a/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts +++ b/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts @@ -116,14 +116,15 @@ export class MsalCachePluginProvider { */ public async writeTokenToLocalCache(token: Token): Promise { let updateCount = 0; + let indexToUpdate = -1; let cache: LocalAccountCache; cache = JSON.parse(await this.readCache(this._localCacheConfiguration)) as LocalAccountCache; if (cache?.tokens) { - cache.tokens.forEach(t => { + cache.tokens.forEach((t, i) => { if (t.key === token.key && t.tenantId === token.tenantId && t.resource === token.resource ) { // Update token - t = token; + indexToUpdate = i; updateCount++; } }); @@ -139,6 +140,9 @@ export class MsalCachePluginProvider { } if (updateCount === 1) { + if (indexToUpdate !== -1) { + cache.tokens[indexToUpdate] = token; + } await this.writeCache(JSON.stringify(cache), this._localCacheConfiguration); } else { diff --git a/src/sql/platform/connection/common/constants.ts b/src/sql/platform/connection/common/constants.ts index 5419e89a7c..b6fc242f2f 100644 --- a/src/sql/platform/connection/common/constants.ts +++ b/src/sql/platform/connection/common/constants.ts @@ -22,6 +22,9 @@ export const defaultEngine = 'defaultEngine'; export const passwordChars = '***************'; +export const enableSqlAuthenticationProviderConfig = 'mssql.enableSqlAuthenticationProvider'; + + /* default authentication type setting name*/ export const defaultAuthenticationType = 'defaultAuthenticationType'; diff --git a/src/sql/platform/connection/common/utils.ts b/src/sql/platform/connection/common/utils.ts index d5888f5ce1..9f0ed7ce54 100644 --- a/src/sql/platform/connection/common/utils.ts +++ b/src/sql/platform/connection/common/utils.ts @@ -6,7 +6,7 @@ import { IConnectionProfile } from 'sql/platform/connection/common/interfaces'; import { ConnectionProfile } from 'sql/platform/connection/common/connectionProfile'; import { ConnectionProfileGroup } from 'sql/platform/connection/common/connectionProfileGroup'; -import * as sqlExtHostTypes from 'sql/workbench/api/common/sqlExtHostTypes' +import * as sqlExtHostTypes from 'sql/workbench/api/common/sqlExtHostTypes'; // CONSTANTS ////////////////////////////////////////////////////////////////////////////////////// const msInH = 3.6e6; diff --git a/src/sql/workbench/services/connection/browser/connectionManagementService.ts b/src/sql/workbench/services/connection/browser/connectionManagementService.ts index 1f8ecb6e8a..d1c5f2e5fc 100644 --- a/src/sql/workbench/services/connection/browser/connectionManagementService.ts +++ b/src/sql/workbench/services/connection/browser/connectionManagementService.ts @@ -57,6 +57,7 @@ import { VIEWLET_ID as ExtensionsViewletID } from 'vs/workbench/contrib/extensio import { IDialogService } from 'vs/platform/dialogs/common/dialogs'; import { IErrorDiagnosticsService } from 'sql/workbench/services/diagnostics/common/errorDiagnosticsService'; import { PasswordChangeDialog } from 'sql/workbench/services/connection/browser/passwordChangeDialog'; +import { enableSqlAuthenticationProviderConfig, mssqlProviderName } from 'sql/platform/connection/common/constants'; export class ConnectionManagementService extends Disposable implements IConnectionManagementService { @@ -1141,6 +1142,12 @@ export class ConnectionManagementService extends Disposable implements IConnecti // We expect connectionProfile to be defined if (connectionProfile && connectionProfile.authenticationType === Constants.AuthenticationType.AzureMFA) { + // We do not need to reconnect for MSSQL Provider, if 'SQL Authentication Provider' setting is enabled. + // Update the token in case it needs refreshing/reauthentication. + if (connectionProfile.providerName === mssqlProviderName && this.getEnableSqlAuthenticationProviderConfig()) { + await this.fillInOrClearToken(connectionProfile); + return true; + } const expiry = connectionProfile.options.expiresOn; if (typeof expiry === 'number' && !Number.isNaN(expiry)) { const currentTime = new Date().getTime() / 1000; @@ -1172,8 +1179,13 @@ export class ConnectionManagementService extends Disposable implements IConnecti } return true; } - else + else { return false; + } + } + + private getEnableSqlAuthenticationProviderConfig(): boolean { + return this._configurationService.getValue(enableSqlAuthenticationProviderConfig) ?? true; } // Request Senders @@ -1598,10 +1610,9 @@ export class ConnectionManagementService extends Disposable implements IConnecti } public async listDatabases(connectionUri: string): Promise { - const self = this; await this.refreshAzureAccountTokenIfNecessary(connectionUri); - if (self.isConnected(connectionUri)) { - return self.sendListDatabasesRequest(connectionUri); + if (this.isConnected(connectionUri)) { + return this.sendListDatabasesRequest(connectionUri); } return Promise.resolve(undefined); }