diff --git a/src/sql/parts/connection/common/connectionManagementService.ts b/src/sql/parts/connection/common/connectionManagementService.ts index f2d4e77155..2674f9d2d9 100644 --- a/src/sql/parts/connection/common/connectionManagementService.ts +++ b/src/sql/parts/connection/common/connectionManagementService.ts @@ -769,8 +769,19 @@ export class ConnectionManagementService extends Disposable implements IConnecti return false; } } - let tokens = await this._accountManagementService.getSecurityToken(account, AzureResource.Sql); - connection.options['azureAccountToken'] = Object.values(tokens)[0].token; + let tokensByTenant = await this._accountManagementService.getSecurityToken(account, AzureResource.Sql); + let token: string; + let tenantId = connection.azureTenantId; + if (tenantId && tokensByTenant[tenantId]) { + token = tokensByTenant[tenantId].token; + } else { + let tokens = Object.values(tokensByTenant); + if (tokens.length === 0) { + return false; + } + token = Object.values(tokensByTenant)[0].token; + } + connection.options['azureAccountToken'] = token; connection.options['password'] = ''; return true; } diff --git a/src/sql/parts/connection/common/connectionProfile.ts b/src/sql/parts/connection/common/connectionProfile.ts index b000e3b40e..acd1d4b0b1 100644 --- a/src/sql/parts/connection/common/connectionProfile.ts +++ b/src/sql/parts/connection/common/connectionProfile.ts @@ -42,6 +42,7 @@ export class ConnectionProfile extends ProviderConnectionInfo implements interfa this.savePassword = model.savePassword; this.saveProfile = model.saveProfile; this._id = model.id; + this.azureTenantId = model.azureTenantId; } else { //Default for a new connection this.savePassword = false; @@ -84,6 +85,14 @@ export class ConnectionProfile extends ProviderConnectionInfo implements interfa this._id = value; } + public get azureTenantId(): string { + return this.options['azureTenantId']; + } + + public set azureTenantId(value: string) { + this.options['azureTenantId'] = value; + } + public get groupFullName(): string { return this._groupName; } @@ -159,7 +168,8 @@ export class ConnectionProfile extends ProviderConnectionInfo implements interfa userName: this.userName, options: this.options, saveProfile: this.saveProfile, - id: this.id + id: this.id, + azureTenantId: this.azureTenantId }; return result; diff --git a/src/sql/parts/connection/connectionDialog/connectionWidget.ts b/src/sql/parts/connection/connectionDialog/connectionWidget.ts index d94fe340be..a8cd862f0a 100644 --- a/src/sql/parts/connection/connectionDialog/connectionWidget.ts +++ b/src/sql/parts/connection/connectionDialog/connectionWidget.ts @@ -52,9 +52,11 @@ export class ConnectionWidget { private _password: string; private _rememberPasswordCheckBox: Checkbox; private _azureAccountDropdown: SelectBox; + private _azureTenantDropdown: SelectBox; private _refreshCredentialsLinkBuilder: Builder; private _addAzureAccountMessage: string = localize('connectionWidget.AddAzureAccount', 'Add an account...'); private readonly _azureProviderId = 'azurePublicCloud'; + private _azureTenantId: string; private _azureAccountList: sqlops.Account[]; private _advancedButton: Button; private _callbacks: IConnectionComponentCallbacks; @@ -215,6 +217,12 @@ export class ConnectionWidget { let refreshCredentialsBuilder = DialogHelper.appendRow(this._tableContainer, '', 'connection-label', 'connection-input', 'azure-account-row refresh-credentials-link'); this._refreshCredentialsLinkBuilder = refreshCredentialsBuilder.a({ href: '#' }).text(localize('connectionWidget.refreshAzureCredentials', 'Refresh account credentials')); + // Azure tenant picker + let tenantLabel = localize('connection.azureTenantDropdownLabel', 'Azure AD tenant'); + let tenantDropdownBuilder = DialogHelper.appendRow(this._tableContainer, tenantLabel, 'connection-label', 'connection-input', 'azure-account-row azure-tenant-row'); + this._azureTenantDropdown = new SelectBox([], undefined, this._contextViewService, tenantDropdownBuilder.getContainer(), { ariaLabel: tenantLabel }); + DialogHelper.appendInputSelectBox(tenantDropdownBuilder, this._azureTenantDropdown); + // Database let databaseOption = this._optionsMaps[ConnectionOptionSpecialType.databaseName]; let databaseNameBuilder = DialogHelper.appendRow(this._tableContainer, databaseOption.displayName, 'connection-label', 'connection-input'); @@ -308,6 +316,13 @@ export class ConnectionWidget { })); } + if (this._azureTenantDropdown) { + this._toDispose.push(styler.attachSelectBoxStyler(this._azureTenantDropdown, this._themeService)); + this._toDispose.push(this._azureTenantDropdown.onDidSelect((selectInfo) => { + this.onAzureTenantSelected(selectInfo.index); + })); + } + if (this._refreshCredentialsLinkBuilder) { this._toDispose.push(this._refreshCredentialsLinkBuilder.on(DOM.EventType.CLICK, async () => { let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); @@ -426,7 +441,7 @@ export class ConnectionWidget { accountDropdownOptions.push(this._addAzureAccountMessage); this._azureAccountDropdown.setOptions(accountDropdownOptions); this._azureAccountDropdown.selectWithOptionName(oldSelection); - this.updateRefreshCredentialsLink(); + await this.onAzureAccountSelected(); } private async updateRefreshCredentialsLink(): Promise { @@ -441,7 +456,6 @@ export class ConnectionWidget { private async onAzureAccountSelected(): Promise { // Reset the dropdown's validation message if the old selection was not valid but the new one is this.validateAzureAccountSelection(false); - this._refreshCredentialsLinkBuilder.display('none'); // Open the add account dialog if needed, then select the added account if (this._azureAccountDropdown.value === this._addAzureAccountMessage) { @@ -461,6 +475,35 @@ export class ConnectionWidget { } this.updateRefreshCredentialsLink(); + + // Display the tenant select box if needed + const hideTenantsClassName = 'hide-azure-tenants'; + let selectedAccount = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); + if (selectedAccount && selectedAccount.properties.tenants && selectedAccount.properties.tenants.length > 1) { + // There are multiple tenants available so let the user select one + let options = selectedAccount.properties.tenants.map(tenant => tenant.displayName); + this._azureTenantDropdown.setOptions(options); + this._tableContainer.getContainer().classList.remove(hideTenantsClassName); + this.onAzureTenantSelected(0); + } else { + if (selectedAccount && selectedAccount.properties.tenants && selectedAccount.properties.tenants.length === 1) { + this._azureTenantId = selectedAccount.properties.tenants[0].id; + } else { + this._azureTenantId = undefined; + } + this._tableContainer.getContainer().classList.add(hideTenantsClassName); + } + } + + private onAzureTenantSelected(tenantIndex: number): void { + this._azureTenantId = undefined; + let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); + if (account && account.properties.tenants) { + let tenant = account.properties.tenants[tenantIndex]; + if (tenant) { + this._azureTenantId = tenant.id; + } + } } private serverNameChanged(serverName: string) { @@ -518,6 +561,7 @@ export class ConnectionWidget { this._passwordInputBox.value = connectionInfo.password ? Constants.passwordChars : ''; this._password = this.getModelValue(connectionInfo.password); this._saveProfile = connectionInfo.saveProfile; + this._azureTenantId = connectionInfo.azureTenantId; let groupName: string; if (this._saveProfile) { if (!connectionInfo.groupFullName) { @@ -551,6 +595,22 @@ export class ConnectionWidget { tableContainerElement.classList.add('hide-azure-accounts'); } + if (this.authType === AuthenticationType.AzureMFA) { + this.fillInAzureAccountOptions().then(async () => { + this._azureAccountDropdown.selectWithOptionName(this.getModelValue(connectionInfo.userName)); + await this.onAzureAccountSelected(); + let tenantId = connectionInfo.azureTenantId; + let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value); + if (account && account.properties.tenants.length > 1) { + let tenant = account.properties.tenants.find(tenant => tenant.id === tenantId); + if (tenant) { + this._azureTenantDropdown.selectWithOptionName(tenant.displayName); + } + this.onAzureTenantSelected(this._azureTenantDropdown.values.indexOf(this._azureTenantDropdown.value)); + } + }); + } + // Disable connect button if - // 1. Authentication type is SQL Login and no username is provided // 2. No server name is provided @@ -716,6 +776,9 @@ export class ConnectionWidget { model.saveProfile = true; model.groupId = this.findGroupId(model.groupFullName); } + if (this.authType === AuthenticationType.AzureMFA) { + model.azureTenantId = this._azureTenantId; + } } return validInputs; } diff --git a/src/sql/parts/connection/connectionDialog/media/connectionDialog.css b/src/sql/parts/connection/connectionDialog/media/connectionDialog.css index 09fcd8f1c4..92a47378aa 100644 --- a/src/sql/parts/connection/connectionDialog/media/connectionDialog.css +++ b/src/sql/parts/connection/connectionDialog/media/connectionDialog.css @@ -128,3 +128,7 @@ .hide-refresh-link .azure-account-row.refresh-credentials-link { display: none; } + +.hide-azure-tenants .azure-tenant-row { + display: none; +} diff --git a/src/sql/sqlops.d.ts b/src/sql/sqlops.d.ts index b78d05f794..2450549a2a 100644 --- a/src/sql/sqlops.d.ts +++ b/src/sql/sqlops.d.ts @@ -214,6 +214,7 @@ declare module 'sqlops' { providerName: string; saveProfile: boolean; id: string; + azureTenantId?: string; } /** diff --git a/src/sqltest/parts/connection/connectionManagementService.test.ts b/src/sqltest/parts/connection/connectionManagementService.test.ts index 3878518a53..8af840919d 100644 --- a/src/sqltest/parts/connection/connectionManagementService.test.ts +++ b/src/sqltest/parts/connection/connectionManagementService.test.ts @@ -881,4 +881,45 @@ suite('SQL ConnectionManagementService tests', () => { assert.equal(profileWithCredentials.userName, username); assert.equal(profileWithCredentials.options['azureAccountToken'], testToken); }); + + test('addSavedPassword fills in Azure access token for selected tenant', async () => { + // Set up a connection profile that uses Azure + let azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile); + azureConnectionProfile.authenticationType = 'AzureMFA'; + let username = 'testuser@microsoft.com'; + azureConnectionProfile.userName = username; + let servername = 'test-database.database.windows.net'; + azureConnectionProfile.serverName = servername; + let azureTenantId = 'testTenant'; + azureConnectionProfile.azureTenantId = azureTenantId; + + // Set up the account management service to return a token for the given user + accountManagementService.setup(x => x.getAccountsForProvider(TypeMoq.It.isAny())).returns(providerId => Promise.resolve([ + { + key: { + accountId: username, + providerId: providerId + }, + displayInfo: undefined, + isStale: false, + properties: undefined + } + ])); + let testToken = 'testToken'; + let returnedTokens = {}; + returnedTokens['azurePublicCloud'] = { token: 'badToken' }; + returnedTokens[azureTenantId] = { token: testToken }; + accountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(returnedTokens)); + connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({ + profile: profile, + savedCred: false + })); + + // If I call addSavedPassword + let profileWithCredentials = await connectionManagementService.addSavedPassword(azureConnectionProfile); + + // Then the returned profile has the account token set corresponding to the requested tenant + assert.equal(profileWithCredentials.userName, username); + assert.equal(profileWithCredentials.options['azureAccountToken'], testToken); + }); }); \ No newline at end of file