diff --git a/extensions/azurecore/src/account-provider/auths/azureAuth.ts b/extensions/azurecore/src/account-provider/auths/azureAuth.ts index 25041a2ff7..9a08c45929 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuth.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuth.ts @@ -212,8 +212,9 @@ export abstract class AzureAuth implements vscode.Disposable { } const resource = this.resources.find(s => s.azureResourceId === azureResource); + if (!resource) { - Logger.error(`Unable to find Azure resource ${azureResource} for account ${account.displayInfo.userId} and tenant ${tenantId}`); + Logger.error(`Unable to find Azure resource ${azureResource}`); return undefined; } @@ -300,7 +301,7 @@ export abstract class AzureAuth implements vscode.Disposable { * re-authentication process for their tenant. */ public async refreshTokenAdal(tenant: Tenant, resource: Resource, refreshToken: RefreshToken | undefined): Promise { - Logger.pii('Refreshing token', [{ name: 'token', objOrArray: refreshToken }], []); + Logger.piiSanitized('Refreshing token', [{ name: 'token', objOrArray: refreshToken }], []); if (refreshToken) { const postData: RefreshTokenPostData = { grant_type: 'refresh_token', @@ -324,10 +325,12 @@ export abstract class AzureAuth implements vscode.Disposable { */ public async getTokenMsal(accountId: string, azureResource: azdata.AzureResource, tenantId: string): Promise { const resource = this.resources.find(s => s.azureResourceId === azureResource); + if (!resource) { - Logger.error(`Error: Could not fetch the azure resource ${azureResource} `); + Logger.error(`Unable to find Azure resource ${azureResource}`); return null; } + // Resource endpoint must end with '/' to form a valid scope for MSAL token request. const endpoint = resource.endpoint.endsWith('/') ? resource.endpoint : resource.endpoint + '/'; @@ -399,7 +402,7 @@ export abstract class AzureAuth implements vscode.Disposable { const tokenUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/token`; const response = await this.makePostRequest(tokenUrl, postData); - Logger.pii('Token: ', [{ name: 'access token', objOrArray: response.data }, { name: 'refresh token', objOrArray: response.data }], []); + Logger.piiSanitized('Token: ', [{ name: 'access token', objOrArray: response.data }, { name: 'refresh token', objOrArray: response.data }], []); if (response.data.error === 'interaction_required') { return this.handleInteractionRequiredAdal(tenant, resource); } @@ -550,13 +553,13 @@ export abstract class AzureAuth implements vscode.Disposable { private async saveTokenAdal(tenant: Tenant, resource: Resource, accountKey: azdata.AccountKey, { accessToken, refreshToken, expiresOn }: OAuthTokenResponse) { const msg = localize('azure.cacheErrorAdd', "Error when adding your account to the cache."); if (!tenant.id || !resource.id) { - Logger.pii('Tenant ID or resource ID was undefined', [], [], tenant, resource); + Logger.piiSanitized('Tenant ID or resource ID was undefined', [], [], tenant, resource); throw new AzureAuthError(msg, 'Adding account to cache failed', undefined); } try { - Logger.pii(`Saving access token`, [{ name: 'access_token', objOrArray: accessToken }], []); + Logger.piiSanitized(`Saving access token`, [{ name: 'access_token', objOrArray: accessToken }], []); await this.tokenCache.saveCredential(`${accountKey.accountId}_access_${resource.id}_${tenant.id}`, JSON.stringify(accessToken)); - Logger.pii(`Saving refresh token`, [{ name: 'refresh_token', objOrArray: refreshToken }], []); + Logger.piiSanitized(`Saving refresh token`, [{ name: 'refresh_token', objOrArray: refreshToken }], []); await this.tokenCache.saveCredential(`${accountKey.accountId}_refresh_${resource.id}_${tenant.id}`, JSON.stringify(refreshToken)); this.memdb.set(`${accountKey.accountId}_${tenant.id}_${resource.id}`, expiresOn); } catch (ex) { @@ -570,7 +573,7 @@ export abstract class AzureAuth implements vscode.Disposable { const parseMsg = localize('azure.cacheErrorParse', "Error when parsing your account from the cache"); if (!tenant.id || !resource.id) { - Logger.pii('Tenant ID or resource ID was undefined', [], [], tenant, resource); + Logger.piiSanitized('Tenant ID or resource ID was undefined', [], [], tenant, resource); throw new AzureAuthError(getMsg, 'Getting account from cache failed', undefined); } @@ -597,7 +600,7 @@ export abstract class AzureAuth implements vscode.Disposable { if (refreshTokenString) { refreshToken = JSON.parse(refreshTokenString); } - Logger.pii('GetSavedToken ', [{ name: 'access', objOrArray: accessToken }, { name: 'refresh', objOrArray: refreshToken }], [], `expiresOn=${expiresOn}`); + Logger.piiSanitized('GetSavedToken ', [{ name: 'access', objOrArray: accessToken }, { name: 'refresh', objOrArray: refreshToken }], [], `expiresOn=${expiresOn}`); return { accessToken, refreshToken, expiresOn }; @@ -683,7 +686,7 @@ export abstract class AzureAuth implements vscode.Disposable { } }; - const messageBody = localize('azurecore.consentDialog.body', "Your tenant '{0} ({1})' requires you to re-authenticate again to access {2} resources. Press Open to start the authentication process.", tenant.displayName, tenant.id, resource.id); + const messageBody = localize('azurecore.consentDialog.body', "Your tenant '{0} ({1})' requires you to re-authenticate again to access {2} resources. Press Open to start the authentication process.", tenant.displayName, tenant.id, resource.endpoint); const result = await vscode.window.showInformationMessage(messageBody, { modal: true }, openItem, closeItem, dontAskAgainItem); if (result?.action) { @@ -788,7 +791,7 @@ export abstract class AzureAuth implements vscode.Disposable { // Intercept response and print out the response for future debugging const response = await axios.post(url, qs.stringify(postData), config); - Logger.pii('POST request ', [{ name: 'data', objOrArray: postData }, { name: 'response', objOrArray: response.data }], [], url); + Logger.piiSanitized('POST request ', [{ name: 'data', objOrArray: postData }, { name: 'response', objOrArray: response.data }], [], url); return response; } @@ -802,7 +805,7 @@ export abstract class AzureAuth implements vscode.Disposable { }; const response = await axios.get(url, config); - Logger.pii('GET request ', [{ name: 'response', objOrArray: response.data.value ?? response.data }], [], url,); + Logger.piiSanitized('GET request ', [{ name: 'response', objOrArray: response.data.value ?? response.data }], [], url,); return response; } diff --git a/extensions/azurecore/src/account-provider/azureAccountProvider.ts b/extensions/azurecore/src/account-provider/azureAccountProvider.ts index d6b04fbc1a..7b577f789e 100644 --- a/extensions/azurecore/src/account-provider/azureAccountProvider.ts +++ b/extensions/azurecore/src/account-provider/azureAccountProvider.ts @@ -144,7 +144,7 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp await this.initCompletePromise; const azureAuth = this.getAuthMethod(account); if (azureAuth) { - Logger.pii(`Getting account security token for ${JSON.stringify(account.key)} (tenant ${tenantId}). Auth Method = ${azureAuth.userFriendlyName}`, [], []); + Logger.piiSanitized(`Getting account security token for ${JSON.stringify(account.key)} (tenant ${tenantId}). Auth Method = ${azureAuth.userFriendlyName}`, [], []); if (this.authLibrary === Constants.AuthLibrary.MSAL) { tenantId = tenantId || account.properties.owningTenant.id; let authResult = await azureAuth.getTokenMsal(account.key.accountId, resource, tenantId); diff --git a/extensions/azurecore/src/account-provider/azureAccountProviderService.ts b/extensions/azurecore/src/account-provider/azureAccountProviderService.ts index 7959039553..6424e6bbd1 100644 --- a/extensions/azurecore/src/account-provider/azureAccountProviderService.ts +++ b/extensions/azurecore/src/account-provider/azureAccountProviderService.ts @@ -201,7 +201,7 @@ export class AzureAccountProviderService implements vscode.Disposable { break; } } else { - Logger.verbose(message); + Logger.pii(message); } } } diff --git a/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts b/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts index e7496dedcf..e7207c56be 100644 --- a/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts +++ b/extensions/azurecore/src/account-provider/utils/msalCachePlugin.ts @@ -24,7 +24,7 @@ export class MsalCachePluginProvider { private _lockTaken: boolean = false; private getLockfilePath(): string { - return this._msalFilePath + '.lock'; + return this._msalFilePath + '.lockfile'; } public getCachePlugin(): ICachePlugin { diff --git a/extensions/azurecore/src/utils/Logger.ts b/extensions/azurecore/src/utils/Logger.ts index faa14dc966..c673eaae96 100644 --- a/extensions/azurecore/src/utils/Logger.ts +++ b/extensions/azurecore/src/utils/Logger.ts @@ -47,7 +47,15 @@ export class Logger { this.write(LogLevel.Verbose, msg, vals); } - + /** + * Logs a message containing PII (when enabled). + * @param msg The initial message to log + */ + static pii(msg: any, ...vals: any[]) { + if (this.piiLogging) { + Logger.write(LogLevel.Pii, msg, vals); + } + } /** * Logs a message containing PII (when enabled). Provides the ability to sanitize or shorten values to hide information or reduce the amount logged. @@ -56,7 +64,7 @@ export class Logger { * @param stringsToShorten Set of strings to shorten * @param vals Any other values to add on to the end of the log message */ - static pii(msg: any, objsToSanitize: { name: string, objOrArray: any | any[] }[], stringsToShorten: { name: string, value: string }[], ...vals: any[]) { + static piiSanitized(msg: any, objsToSanitize: { name: string, objOrArray: any | any[] }[], stringsToShorten: { name: string, value: string }[], ...vals: any[]) { if (this.piiLogging) { msg = [ msg, diff --git a/extensions/mssql/package.json b/extensions/mssql/package.json index ec7a126e65..2e7be3f798 100644 --- a/extensions/mssql/package.json +++ b/extensions/mssql/package.json @@ -197,6 +197,11 @@ "default": false, "description": "%mssql.logDebugInfo%" }, + "mssql.piiLogging": { + "type": "boolean", + "default": false, + "description": "%mssql.piiLogging%" + }, "mssql.tracingLevel": { "type": "string", "description": "%mssql.tracingLevel%", @@ -413,6 +418,11 @@ "description": "%mssql.parallelMessageProcessing%", "default": false }, + "mssql.enableSqlAuthenticationProvider": { + "type": "boolean", + "description": "%mssql.enableSqlAuthenticationProvider%", + "default": false + }, "mssql.tableDesigner.preloadDatabaseModel": { "type": "boolean", "default": false, diff --git a/extensions/mssql/package.nls.json b/extensions/mssql/package.nls.json index 2c79e0cbc3..be482b8a33 100644 --- a/extensions/mssql/package.nls.json +++ b/extensions/mssql/package.nls.json @@ -35,6 +35,7 @@ "mssql.format.placeSelectStatementReferencesOnNewLine": "Should references to objects in a select statements be split into separate lines? E.g. for 'SELECT C1, C2 FROM T1' both C1 and C2 will be on separate lines", "mssql.logDebugInfo": "[Optional] Log debug output to the console (View -> Output) and then select appropriate output channel from the dropdown", "mssql.tracingLevel": "[Optional] Log level for backend services. Azure Data Studio generates a file name every time it starts and if the file already exists the logs entries are appended to that file. For cleanup of old log files see logRetentionMinutes and logFilesRemovalLimit settings. The default tracingLevel does not log much. Changing verbosity could lead to extensive logging and disk space requirements for the logs. Error includes Critical, Warning includes Error, Information includes Warning and Verbose includes Information", + "mssql.piiLogging": "[Optional] Should Personally Identifiable Information (PII) be logged in the log file.", "mssql.logRetentionMinutes": "Number of minutes to retain log files for backend services. Default is 1 week.", "mssql.logFilesRemovalLimit": "Maximum number of old files to remove upon startup that have expired mssql.logRetentionMinutes. Files that do not get cleaned up due to this limitation get cleaned up next time Azure Data Studio starts up.", "mssql.intelliSense.enableIntelliSense": "Should IntelliSense be enabled", @@ -180,7 +181,8 @@ "title.newTable": "New Table", "title.designTable": "Design", "title.changeNotebookConnection": "Change SQL Notebook Connection", - "mssql.parallelMessageProcessing" : "[Experimental] Whether the requests to the SQL Tools Service should be handled in parallel. This is introduced to discover the issues there might be when handling all requests in parallel. The default value is false. Relaunch of ADS is required when the value is changed.", + "mssql.parallelMessageProcessing" : "[Experimental] Whether the requests to the SQL Tools Service should be handled in parallel. This is introduced to discover the issues there might be when handling all requests in parallel. The default value is false. Azure Data Studio is required to be relaunched when the value is changed.", + "mssql.enableSqlAuthenticationProvider" : "Enables use of the Sql Authentication Provider for 'Active Directory Interactive' authentication mode when user selects 'AzureMFA' authentication. This enables Server-side resource endpoint integration when fetching access tokens. This option is only supported for 'MSAL' Azure Authentication Library. Azure Data Studio is required to be relaunched when the value is changed.", "mssql.tableDesigner.preloadDatabaseModel": "Whether to preload the database model when the database node in the object explorer is expanded. When enabled, the loading time of table designer can be reduced. Note: You might see higher than normal memory usage if you need to expand a lot of database nodes.", "mssql.objectExplorer.groupBySchema": "When enabled, the database objects in Object Explorer will be categorized by schema.", "mssql.objectExplorer.enableGroupBySchema":"Enable Group By Schema", diff --git a/extensions/mssql/src/contracts.ts b/extensions/mssql/src/contracts.ts index b9b9161289..4e8157c802 100644 --- a/extensions/mssql/src/contracts.ts +++ b/extensions/mssql/src/contracts.ts @@ -33,10 +33,10 @@ export class TelemetryParams { // ------------------------------- < Security Token Request > ------------------------------------------ export interface RequestSecurityTokenParams { - authority: string; provider: string; + authority: string; resource: string; - scope: string; + scopes: string[]; } export interface RequestSecurityTokenResponse { diff --git a/extensions/mssql/src/sqlToolsServer.ts b/extensions/mssql/src/sqlToolsServer.ts index 9589dfe9a3..53a1371ce8 100644 --- a/extensions/mssql/src/sqlToolsServer.ts +++ b/extensions/mssql/src/sqlToolsServer.ts @@ -9,7 +9,7 @@ import * as Constants from './constants'; import * as vscode from 'vscode'; import * as azdata from 'azdata'; import * as path from 'path'; -import { getCommonLaunchArgsAndCleanupOldLogFiles, getConfigTracingLevel, getOrDownloadServer, getParallelMessageProcessingConfig, TracingLevel } from './utils'; +import { getAzureAuthenticationLibraryConfig, getCommonLaunchArgsAndCleanupOldLogFiles, getConfigTracingLevel, getEnableSqlAuthenticationProviderConfig, getOrDownloadServer, getParallelMessageProcessingConfig, TracingLevel } from './utils'; import { TelemetryReporter, LanguageClientErrorHandler } from './telemetry'; import { SqlOpsDataClient, ClientOptions } from 'dataprotocol-client'; import { TelemetryFeature, AgentServicesFeature, SerializationFeature, AccountFeature, SqlAssessmentServicesFeature, ProfilerFeature, TableDesignerFeature, ExecutionPlanServiceFeature } from './features'; @@ -58,7 +58,7 @@ export class SqlToolsServer { const serverPath = await this.download(context); this.installDirectory = path.dirname(serverPath); const installationComplete = Date.now(); - let serverOptions = await generateServerOptions(context.extensionContext.logPath, serverPath); + let serverOptions = generateServerOptions(context.extensionContext.logPath, serverPath); let clientOptions = getClientOptions(context); this.client = new SqlOpsDataClient('mssql', Constants.serviceName, serverOptions, clientOptions); const processStart = Date.now(); @@ -117,12 +117,17 @@ export class SqlToolsServer { } } -async function generateServerOptions(logPath: string, executablePath: string): Promise { +function generateServerOptions(logPath: string, executablePath: string): ServerOptions { const launchArgs = getCommonLaunchArgsAndCleanupOldLogFiles(logPath, 'sqltools.log', executablePath); - const enableAsyncMessageProcessing = await getParallelMessageProcessingConfig(); + const enableAsyncMessageProcessing = getParallelMessageProcessingConfig(); if (enableAsyncMessageProcessing) { launchArgs.push('--parallel-message-processing'); } + const enableSqlAuthenticationProvider = getEnableSqlAuthenticationProviderConfig(); + const azureAuthLibrary = getAzureAuthenticationLibraryConfig(); + if (azureAuthLibrary === 'MSAL' && enableSqlAuthenticationProvider === true) { + launchArgs.push('--enable-sql-authentication-provider'); + } return { command: executablePath, args: launchArgs, transport: TransportKind.stdio }; } diff --git a/extensions/mssql/src/utils.ts b/extensions/mssql/src/utils.ts index 570ec76618..c2c693eb65 100644 --- a/extensions/mssql/src/utils.ts +++ b/extensions/mssql/src/utils.ts @@ -14,13 +14,17 @@ import { IConfig, ServerProvider } from '@microsoft/ads-service-downloader'; import { env } from 'process'; const configTracingLevel = 'tracingLevel'; +const configPiiLogging = 'piiLogging'; const configLogRetentionMinutes = 'logRetentionMinutes'; const configLogFilesRemovalLimit = 'logFilesRemovalLimit'; const extensionConfigSectionName = 'mssql'; const configLogDebugInfo = 'logDebugInfo'; const parallelMessageProcessingConfig = 'parallelMessageProcessing'; +const enableSqlAuthenticationProviderConfig = 'enableSqlAuthenticationProvider'; const tableDesignerPreloadConfig = 'tableDesigner.preloadDatabaseModel'; +const azureExtensionConfigName = 'azure'; +const azureAuthenticationLibraryConfig = 'authenticationLibrary'; /** * * @returns Whether the current OS is linux or not @@ -62,7 +66,17 @@ export function removeOldLogFiles(logPath: string, prefix: string): JSON { } export function getConfiguration(config: string = extensionConfigSectionName): vscode.WorkspaceConfiguration { - return vscode.workspace.getConfiguration(extensionConfigSectionName); + return vscode.workspace.getConfiguration(config); +} +/** + * We need Azure core extension configuration for fetching Authentication Library setting in use. + * This is required for 'enableSqlAuthenticationProvider' to be enabled (as it applies to MSAL only). + * This can be removed in future when ADAL support is dropped. + * @param config Azure core extension configuration section name + * @returns Azure core extension config section + */ +export function getAzureCoreExtConfiguration(config: string = azureExtensionConfigName): vscode.WorkspaceConfiguration { + return vscode.workspace.getConfiguration(config); } export function getConfigLogFilesRemovalLimit(): number { @@ -105,6 +119,15 @@ export function getConfigTracingLevel(): TracingLevel { } } +export function getConfigPiiLogging(): boolean { + let config = getConfiguration(); + if (config) { + return config[configPiiLogging]; + } else { + return false; + } +} + export function getConfigPreloadDatabaseModel(): boolean { let config = getConfiguration(); if (config) { @@ -121,23 +144,47 @@ export function setConfigPreloadDatabaseModel(enable: boolean): void { } } -export async function getParallelMessageProcessingConfig(): Promise { +export function getParallelMessageProcessingConfig(): boolean { const config = getConfiguration(); if (!config) { return false; } const setting = config.inspect(parallelMessageProcessingConfig); - // For dev environment, we want to enable the feature by default unless it is set explicitely. - // Note: the quality property is not set for dev environment, we can use this to determine whether it is dev environment. return (azdata.env.quality === azdata.env.AppQuality.dev && setting.globalValue === undefined && setting.workspaceValue === undefined) ? true : config[parallelMessageProcessingConfig]; } +export function getAzureAuthenticationLibraryConfig(): string { + const config = getAzureCoreExtConfiguration(); + if (config) { + return config.has(azureAuthenticationLibraryConfig) + ? config.get(azureAuthenticationLibraryConfig) + : 'MSAL'; // default Auth library + } + else { + return 'MSAL'; + } +} + +export function getEnableSqlAuthenticationProviderConfig(): boolean { + const config = getConfiguration(); + if (config) { + return config.has(enableSqlAuthenticationProviderConfig) + ? config.get(enableSqlAuthenticationProviderConfig) + : false; // disabled by default + } + else { + return false; + } +} + export function getLogFileName(prefix: string, pid: number): string { return `${prefix}_${pid}.log`; } export function getCommonLaunchArgsAndCleanupOldLogFiles(logPath: string, fileName: string, executablePath: string): string[] { let launchArgs = []; + // Application Name determines app storage location or user data path. + launchArgs.push('--application-name', 'azuredatastudio'); launchArgs.push(`--locale`, vscode.env.language); launchArgs.push('--log-file'); @@ -151,6 +198,9 @@ export function getCommonLaunchArgsAndCleanupOldLogFiles(logPath: string, fileNa console.log(`Old log files deletion report: ${JSON.stringify(deletedLogFiles)}`); launchArgs.push('--tracing-level'); launchArgs.push(getConfigTracingLevel()); + if (getConfigPiiLogging()) { + launchArgs.push('--pii-logging'); + } // Always enable autoflush so that log entries are written immediately to disk, otherwise we can end up with partial logs launchArgs.push('--autoflush-log'); return launchArgs; diff --git a/src/sql/azdata.d.ts b/src/sql/azdata.d.ts index 58a969f117..4a9c93957b 100644 --- a/src/sql/azdata.d.ts +++ b/src/sql/azdata.d.ts @@ -2499,7 +2499,11 @@ declare module 'azdata' { /** * Power BI */ - PowerBi = 11 + PowerBi = 11, + /** + * Represents custom resource URIs as received from server endpoint. + */ + Custom = 12 } export interface DidChangeAccountsParams { diff --git a/src/sql/platform/connection/common/connectionManagement.ts b/src/sql/platform/connection/common/connectionManagement.ts index 4198d542b0..78827cf421 100644 --- a/src/sql/platform/connection/common/connectionManagement.ts +++ b/src/sql/platform/connection/common/connectionManagement.ts @@ -135,6 +135,12 @@ export interface IConnectionManagementService { */ findExistingConnection(connection: IConnectionProfile, purpose?: 'dashboard' | 'insights' | 'connection'): ConnectionProfile; + /** + * Fixes treeItem payload to consider defaultAuthenticationType and any other user settings. + * @param profile Connection profile as received from treeItem. + */ + fixProfile(profile?: azdata.IConnectionProfile): Promise; + /** * If there's already a connection for given profile and purpose, returns the ownerUri for the connection * otherwise tries to make a connection and returns the owner uri when connection is complete diff --git a/src/sql/platform/connection/test/common/testConnectionManagementService.ts b/src/sql/platform/connection/test/common/testConnectionManagementService.ts index 4fc15b01f5..41c24f36af 100644 --- a/src/sql/platform/connection/test/common/testConnectionManagementService.ts +++ b/src/sql/platform/connection/test/common/testConnectionManagementService.ts @@ -170,6 +170,10 @@ export class TestConnectionManagementService implements IConnectionManagementSer return undefined!; } + async fixProfile(profile?: IConnectionProfile): Promise { + return profile; + } + connect(connection: IConnectionProfile, uri: string, options?: IConnectionCompletionOptions, callbacks?: IConnectionCallbacks): Promise { return new Promise((resolve, reject) => { resolve({ connected: true, errorMessage: undefined!, errorCode: undefined!, messageDetails: undefined! }); diff --git a/src/sql/workbench/api/browser/mainThreadAccountManagement.ts b/src/sql/workbench/api/browser/mainThreadAccountManagement.ts index 97a9147e45..ddb572e812 100644 --- a/src/sql/workbench/api/browser/mainThreadAccountManagement.ts +++ b/src/sql/workbench/api/browser/mainThreadAccountManagement.ts @@ -72,7 +72,6 @@ export class MainThreadAccountManagement extends Disposable implements MainThrea clear(accountKey: azdata.AccountKey): Thenable { return self._proxy.$clear(handle, accountKey); }, - getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{}> { return self._proxy.$getSecurityToken(account, resource); }, diff --git a/src/sql/workbench/api/common/extHostAccountManagement.ts b/src/sql/workbench/api/common/extHostAccountManagement.ts index f9dc3df68c..33c6527a34 100644 --- a/src/sql/workbench/api/common/extHostAccountManagement.ts +++ b/src/sql/workbench/api/common/extHostAccountManagement.ts @@ -114,7 +114,6 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape { }); } - public get onDidChangeAccounts(): Event { return this._onDidChangeAccounts.event; } diff --git a/src/sql/workbench/api/common/sqlExtHostTypes.ts b/src/sql/workbench/api/common/sqlExtHostTypes.ts index 15e2702ba3..306f5dfe0f 100644 --- a/src/sql/workbench/api/common/sqlExtHostTypes.ts +++ b/src/sql/workbench/api/common/sqlExtHostTypes.ts @@ -474,7 +474,8 @@ export enum AzureResource { AzureLogAnalytics = 8, AzureStorage = 9, AzureKusto = 10, - PowerBi = 11 + PowerBi = 11, + Custom = 12 // Handles custom resource URIs as received from server endpoint. } export class TreeItem extends vsExtTypes.TreeItem { diff --git a/src/sql/workbench/contrib/backup/browser/backup.contribution.ts b/src/sql/workbench/contrib/backup/browser/backup.contribution.ts index 72c4ec9d39..74eff15f9e 100644 --- a/src/sql/workbench/contrib/backup/browser/backup.contribution.ts +++ b/src/sql/workbench/contrib/backup/browser/backup.contribution.ts @@ -20,6 +20,7 @@ import { ConnectionContextKey } from 'sql/workbench/services/connection/common/c import { ServerInfoContextKey } from 'sql/workbench/services/connection/common/serverInfoContextKey'; import { ServicesAccessor, IInstantiationService } from 'vs/platform/instantiation/common/instantiation'; import { DatabaseEngineEdition } from 'sql/workbench/api/common/sqlExtHostTypes'; +import { IConnectionManagementService } from 'sql/platform/connection/common/connectionManagement'; new BackupAction().registerTask(); @@ -30,7 +31,9 @@ CommandsRegistry.registerCommand({ handler: async (accessor, args: TreeViewItemHandleArg) => { if (args.$treeItem?.payload) { const commandService = accessor.get(ICommandService); - return commandService.executeCommand(BackupAction.ID, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + return commandService.executeCommand(BackupAction.ID, payload); } } }); @@ -69,9 +72,11 @@ MenuRegistry.appendMenuItem(MenuId.ObjectExplorerItemContext, { // dashboard explorer const ExplorerBackUpActionID = 'explorer.backup'; -CommandsRegistry.registerCommand(ExplorerBackUpActionID, (accessor, context: ManageActionContext) => { +CommandsRegistry.registerCommand(ExplorerBackUpActionID, async (accessor, context: ManageActionContext) => { const commandService = accessor.get(ICommandService); - return commandService.executeCommand(BackupAction.ID, context.profile); + const connectionService = accessor.get(IConnectionManagementService); + let profile = await connectionService.fixProfile(context.profile); + return commandService.executeCommand(BackupAction.ID, profile); }); MenuRegistry.appendMenuItem(MenuId.ExplorerWidgetContext, { diff --git a/src/sql/workbench/contrib/dashboard/browser/dashboardActions.ts b/src/sql/workbench/contrib/dashboard/browser/dashboardActions.ts index a0e5b89bbe..1344547e65 100644 --- a/src/sql/workbench/contrib/dashboard/browser/dashboardActions.ts +++ b/src/sql/workbench/contrib/dashboard/browser/dashboardActions.ts @@ -42,7 +42,8 @@ CommandsRegistry.registerCommand({ showConnectionDialogOnError: true, showFirewallRuleOnError: true }; - let profile = new ConnectionProfile(capabilitiesService, args.$treeItem.payload); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + let profile = new ConnectionProfile(capabilitiesService, payload); let uri = generateUri(profile, 'dashboard'); return connectionService.connect(new ConnectionProfile(capabilitiesService, args.$treeItem.payload), uri, options); } @@ -96,7 +97,8 @@ export class OEManageConnectionAction extends Action { if (actionContext instanceof ObjectExplorerActionsContext) { // Must use a real connection profile for this action due to lookup - connectionProfile = ConnectionProfile.fromIConnectionProfile(this._capabilitiesService, actionContext.connectionProfile); + let updatedIConnProfile = await this._connectionManagementService.fixProfile(actionContext.connectionProfile); + connectionProfile = ConnectionProfile.fromIConnectionProfile(this._capabilitiesService, updatedIConnProfile); if (!actionContext.isConnectionNode) { treeNode = await getTreeNode(actionContext, this._objectExplorerService); if (TreeUpdateUtils.isDatabaseNode(treeNode)) { diff --git a/src/sql/workbench/contrib/notebook/browser/notebook.contribution.ts b/src/sql/workbench/contrib/notebook/browser/notebook.contribution.ts index a9ffa51ddb..588be3c155 100644 --- a/src/sql/workbench/contrib/notebook/browser/notebook.contribution.ts +++ b/src/sql/workbench/contrib/notebook/browser/notebook.contribution.ts @@ -67,6 +67,7 @@ import { IViewDescriptorService, ViewContainerLocation } from 'vs/workbench/comm import { ToggleTabFocusModeAction } from 'vs/editor/contrib/toggleTabFocusMode/browser/toggleTabFocusMode'; import { ActiveEditorContext } from 'vs/workbench/common/contextkeys'; import { ILanguageService } from 'vs/editor/common/languages/language'; +import { IConnectionManagementService } from 'sql/platform/connection/common/connectionManagement'; Registry.as(EditorExtensions.EditorFactory) .registerEditorSerializer(FileNotebookInput.ID, FileNoteBookEditorSerializer); @@ -101,9 +102,11 @@ const DE_NEW_NOTEBOOK_COMMAND_ID = 'dataExplorer.newNotebook'; // New Notebook CommandsRegistry.registerCommand({ id: DE_NEW_NOTEBOOK_COMMAND_ID, - handler: (accessor, args: TreeViewItemHandleArg) => { + handler: async (accessor, args: TreeViewItemHandleArg) => { const instantiationService = accessor.get(IInstantiationService); - const connectedContext: ConnectedContext = { connectionProfile: args.$treeItem.payload }; + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + const connectedContext: ConnectedContext = { connectionProfile: payload }; return instantiationService.createInstance(NewNotebookAction, NewNotebookAction.ID, NewNotebookAction.LABEL).run({ connectionProfile: connectedContext.connectionProfile, isConnectionNode: false, nodeInfo: undefined }); } }); diff --git a/src/sql/workbench/contrib/objectExplorer/test/browser/connectionTreeActions.test.ts b/src/sql/workbench/contrib/objectExplorer/test/browser/connectionTreeActions.test.ts index 9fed886852..cbcf58588a 100644 --- a/src/sql/workbench/contrib/objectExplorer/test/browser/connectionTreeActions.test.ts +++ b/src/sql/workbench/contrib/objectExplorer/test/browser/connectionTreeActions.test.ts @@ -43,6 +43,7 @@ import { TestEditorService } from 'vs/workbench/test/browser/workbenchTestServic import { IEditorService } from 'vs/workbench/services/editor/common/editorService'; import { TestDialogService } from 'vs/platform/dialogs/test/common/testDialogService'; import { IDialogService } from 'vs/platform/dialogs/common/dialogs'; +import { IConnectionProfile } from 'sql/platform/connection/common/interfaces'; suite('SQL Connection Tree Action tests', () => { let errorMessageService: TypeMoq.Mock; @@ -61,7 +62,7 @@ suite('SQL Connection Tree Action tests', () => { errorMessageService.setup(x => x.showDialog(Severity.Error, TypeMoq.It.isAnyString(), TypeMoq.It.isAnyString())).returns(() => nothing); }); - function createConnectionManagementService(isConnectedReturnValue: boolean, profileToReturn: ConnectionProfile): TypeMoq.Mock { + function createConnectionManagementService(isConnectedReturnValue: boolean, profileToReturn: IConnectionProfile): TypeMoq.Mock { let connectionManagementService = TypeMoq.Mock.ofType(TestConnectionManagementService, TypeMoq.MockBehavior.Strict); connectionManagementService.callBase = true; connectionManagementService.setup(x => x.isConnected(undefined, TypeMoq.It.isAny())).returns(() => isConnectedReturnValue); @@ -77,6 +78,7 @@ suite('SQL Connection Tree Action tests', () => { connectionManagementService.setup(x => x.deleteConnectionGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(true)); connectionManagementService.setup(x => x.deleteConnection(TypeMoq.It.isAny())).returns(() => Promise.resolve(true)); connectionManagementService.setup(x => x.getConnectionProfile(TypeMoq.It.isAny())).returns(() => profileToReturn); + connectionManagementService.setup(x => x.fixProfile(TypeMoq.It.isAny())).returns(() => new Promise((resolve, reject) => resolve(profileToReturn))); connectionManagementService.setup(x => x.showEditConnectionDialog(TypeMoq.It.isAny())).returns(() => new Promise((resolve, reject) => resolve())); return connectionManagementService; } @@ -117,7 +119,7 @@ suite('SQL Connection Tree Action tests', () => { test('ManageConnectionAction - test if connect is called for manage action if not already connected', () => { let isConnectedReturnValue: boolean = false; - let connection: ConnectionProfile = new ConnectionProfile(capabilitiesService, { + let connection: IConnectionProfile = new ConnectionProfile(capabilitiesService, { connectionName: 'Test', savePassword: false, groupFullName: 'testGroup', @@ -197,7 +199,7 @@ suite('SQL Connection Tree Action tests', () => { viewsService); let actionContext = new ObjectExplorerActionsContext(); - actionContext.connectionProfile = connection.toIConnectionProfile(); + actionContext.connectionProfile = connection; actionContext.isConnectionNode = true; return manageConnectionAction.run(actionContext).then(() => { connectionManagementService.verify(x => x.connect(TypeMoq.It.isAny(), undefined, TypeMoq.It.isAny(), undefined), TypeMoq.Times.once()); diff --git a/src/sql/workbench/contrib/query/browser/queryActions.ts b/src/sql/workbench/contrib/query/browser/queryActions.ts index 9e97ff5384..b0c8bf9917 100644 --- a/src/sql/workbench/contrib/query/browser/queryActions.ts +++ b/src/sql/workbench/contrib/query/browser/queryActions.ts @@ -182,7 +182,8 @@ CommandsRegistry.registerCommand({ showConnectionDialogOnError: true, showFirewallRuleOnError: true }; - return connectionService.connect(new ConnectionProfile(capabilitiesService, args.$treeItem.payload), owner.uri, options); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + return connectionService.connect(new ConnectionProfile(capabilitiesService, payload), owner.uri, options); } return true; } diff --git a/src/sql/workbench/contrib/restore/browser/restore.contribution.ts b/src/sql/workbench/contrib/restore/browser/restore.contribution.ts index 371ea77e36..982764767a 100644 --- a/src/sql/workbench/contrib/restore/browser/restore.contribution.ts +++ b/src/sql/workbench/contrib/restore/browser/restore.contribution.ts @@ -19,6 +19,7 @@ import { ManageActionContext } from 'sql/workbench/browser/actions'; import { ItemContextKey } from 'sql/workbench/contrib/dashboard/browser/widgets/explorer/explorerContext'; import { ServerInfoContextKey } from 'sql/workbench/services/connection/common/serverInfoContextKey'; import { DatabaseEngineEdition } from 'sql/workbench/api/common/sqlExtHostTypes'; +import { IConnectionManagementService } from 'sql/platform/connection/common/connectionManagement'; new RestoreAction().registerTask(); @@ -29,7 +30,9 @@ CommandsRegistry.registerCommand({ handler: async (accessor, args: TreeViewItemHandleArg) => { if (args.$treeItem?.payload) { const commandService = accessor.get(ICommandService); - return commandService.executeCommand(RestoreAction.ID, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + return commandService.executeCommand(RestoreAction.ID, payload); } } }); @@ -51,9 +54,11 @@ MenuRegistry.appendMenuItem(MenuId.DataExplorerContext, { const OE_RESTORE_COMMAND_ID = 'objectExplorer.restore'; CommandsRegistry.registerCommand({ id: OE_RESTORE_COMMAND_ID, - handler: (accessor, args: ObjectExplorerActionsContext) => { + handler: async (accessor, args: ObjectExplorerActionsContext) => { const commandService = accessor.get(ICommandService); - return commandService.executeCommand(RestoreAction.ID, args.connectionProfile); + const connectionService = accessor.get(IConnectionManagementService); + let profile = await connectionService.fixProfile(args.connectionProfile); + return commandService.executeCommand(RestoreAction.ID, profile); } }); @@ -69,9 +74,11 @@ MenuRegistry.appendMenuItem(MenuId.ObjectExplorerItemContext, { }); const ExplorerRestoreActionID = 'explorer.restore'; -CommandsRegistry.registerCommand(ExplorerRestoreActionID, (accessor, context: ManageActionContext) => { +CommandsRegistry.registerCommand(ExplorerRestoreActionID, async (accessor, context: ManageActionContext) => { const commandService = accessor.get(ICommandService); - return commandService.executeCommand(RestoreAction.ID, context.profile); + const connectionService = accessor.get(IConnectionManagementService); + let profile = await connectionService.fixProfile(context.profile); + return commandService.executeCommand(RestoreAction.ID, profile); }); MenuRegistry.appendMenuItem(MenuId.ExplorerWidgetContext, { diff --git a/src/sql/workbench/contrib/scripting/browser/scriptingActions.ts b/src/sql/workbench/contrib/scripting/browser/scriptingActions.ts index bcba9a9ab0..a23a7b1200 100644 --- a/src/sql/workbench/contrib/scripting/browser/scriptingActions.ts +++ b/src/sql/workbench/contrib/scripting/browser/scriptingActions.ts @@ -49,7 +49,9 @@ CommandsRegistry.registerCommand({ const scriptingService = accessor.get(IScriptingService); const errorMessageService = accessor.get(IErrorMessageService); const progressService = accessor.get(IProgressService); - const profile = new ConnectionProfile(capabilitiesService, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + const profile = new ConnectionProfile(capabilitiesService, payload); const baseContext: BaseActionContext = { profile: profile, object: oeShimService.getNodeInfoForTreeItem(args.$treeItem)!.metadata @@ -73,7 +75,9 @@ CommandsRegistry.registerCommand({ const scriptingService = accessor.get(IScriptingService); const errorMessageService = accessor.get(IErrorMessageService); const progressService = accessor.get(IProgressService); - const profile = new ConnectionProfile(capabilitiesService, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + const profile = new ConnectionProfile(capabilitiesService, payload); const baseContext: BaseActionContext = { profile: profile, object: oeShimService.getNodeInfoForTreeItem(args.$treeItem)!.metadata @@ -97,7 +101,9 @@ CommandsRegistry.registerCommand({ const scriptingService = accessor.get(IScriptingService); const progressService = accessor.get(IProgressService); const errorMessageService = accessor.get(IErrorMessageService); - const profile = new ConnectionProfile(capabilitiesService, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + const profile = new ConnectionProfile(capabilitiesService, payload); const baseContext: BaseActionContext = { profile: profile, object: oeShimService.getNodeInfoForTreeItem(args.$treeItem)!.metadata @@ -121,7 +127,9 @@ CommandsRegistry.registerCommand({ const scriptingService = accessor.get(IScriptingService); const progressService = accessor.get(IProgressService); const errorMessageService = accessor.get(IErrorMessageService); - const profile = new ConnectionProfile(capabilitiesService, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + const profile = new ConnectionProfile(capabilitiesService, payload); const baseContext: BaseActionContext = { profile: profile, object: oeShimService.getNodeInfoForTreeItem(args.$treeItem)!.metadata @@ -145,7 +153,9 @@ CommandsRegistry.registerCommand({ const scriptingService = accessor.get(IScriptingService); const progressService = accessor.get(IProgressService); const errorMessageService = accessor.get(IErrorMessageService); - const profile = new ConnectionProfile(capabilitiesService, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + const profile = new ConnectionProfile(capabilitiesService, payload); const baseContext: BaseActionContext = { profile: profile, object: oeShimService.getNodeInfoForTreeItem(args.$treeItem)!.metadata @@ -169,7 +179,9 @@ CommandsRegistry.registerCommand({ const scriptingService = accessor.get(IScriptingService); const progressService = accessor.get(IProgressService); const errorMessageService = accessor.get(IErrorMessageService); - const profile = new ConnectionProfile(capabilitiesService, args.$treeItem.payload); + const connectionService = accessor.get(IConnectionManagementService); + let payload = await connectionService.fixProfile(args.$treeItem.payload); + const profile = new ConnectionProfile(capabilitiesService, payload); const baseContext: BaseActionContext = { profile: profile, object: oeShimService.getNodeInfoForTreeItem(args.$treeItem)!.metadata diff --git a/src/sql/workbench/services/connection/browser/connectionManagementService.ts b/src/sql/workbench/services/connection/browser/connectionManagementService.ts index b2c5775bfc..56aa4cb884 100644 --- a/src/sql/workbench/services/connection/browser/connectionManagementService.ts +++ b/src/sql/workbench/services/connection/browser/connectionManagementService.ts @@ -403,6 +403,24 @@ export class ConnectionManagementService extends Disposable implements IConnecti return this.tryConnect(connection, input, options); } + public async fixProfile(profile?: interfaces.IConnectionProfile): Promise { + if (profile) { + if (profile.authenticationType !== undefined && profile.authenticationType === '') { + // we need to set auth type here, because it's value is part of the session key + profile.authenticationType = this.getDefaultAuthenticationTypeId(profile.providerName); + } + + // If this is Azure MFA Authentication, fix username to azure Account user. Falls back to current user name. + // This is required, as by default, server login / administrator is the username. + if (profile.authenticationType === 'AzureMFA') { + let accounts = await this._accountManagementService?.getAccounts(); + profile.userName = accounts?.find(a => a.key.accountId === profile.azureAccount)?.displayInfo.displayName + ?? profile.userName; + } + } + return profile; + } + /** * If there's already a connection for given profile and purpose, returns the ownerUri for the connection * otherwise tries to make a connection and returns the owner uri when connection is complete diff --git a/src/sql/workbench/services/objectExplorer/browser/objectExplorerViewTreeShim.ts b/src/sql/workbench/services/objectExplorer/browser/objectExplorerViewTreeShim.ts index 790c52e3a4..0120a3d8b6 100644 --- a/src/sql/workbench/services/objectExplorer/browser/objectExplorerViewTreeShim.ts +++ b/src/sql/workbench/services/objectExplorer/browser/objectExplorerViewTreeShim.ts @@ -43,13 +43,13 @@ export class OEShimService extends Disposable implements IOEShimService { @IObjectExplorerService private oe: IObjectExplorerService, @IConnectionManagementService private cm: IConnectionManagementService, @ICapabilitiesService private capabilities: ICapabilitiesService, - @IConfigurationService private configurationService: IConfigurationService ) { super(); } private async createSession(viewId: string, providerId: string, node: ITreeItem): Promise { - let connProfile = new ConnectionProfile(this.capabilities, node.payload); + let payload = await this.cm.fixProfile(node.payload); + let connProfile = new ConnectionProfile(this.capabilities, payload); connProfile.saveProfile = false; if (this.cm.providerRegistered(providerId)) { connProfile = await this.connectOrPrompt(connProfile); @@ -119,9 +119,7 @@ export class OEShimService extends Disposable implements IOEShimService { public async getChildren(node: ITreeItem, viewId: string): Promise { if (node.payload) { - if (node.payload.authenticationType !== undefined && node.payload.authenticationType === '') { - node.payload.authenticationType = this.getDefaultAuthenticationType(this.configurationService); // we need to set auth type here, because it's value is part of the session key - } + node.payload = await this.cm.fixProfile(node.payload); if (node.sessionId === undefined) { node.sessionId = await this.createSession(viewId, node.childProvider!, node);