Enable Azure Active Directory MFA authentication (#3125)

This commit is contained in:
Matt Irvine
2018-11-27 11:13:47 -08:00
committed by GitHub
parent d646b4729b
commit cb72865dcc
33 changed files with 369 additions and 109 deletions

View File

@@ -69,8 +69,8 @@ export class AzureAccountProvider implements sqlops.AccountProvider {
return this._tokenCache.clear(); return this._tokenCache.clear();
} }
public getSecurityToken(account: AzureAccount): Thenable<AzureAccountSecurityTokenCollection> { public getSecurityToken(account: AzureAccount, resource: sqlops.AzureResource): Thenable<AzureAccountSecurityTokenCollection> {
return this.doIfInitialized(() => this.getAccessTokens(account)); return this.doIfInitialized(() => this.getAccessTokens(account, resource));
} }
public initialize(restoredAccounts: sqlops.Account[]): Thenable<sqlops.Account[]> { public initialize(restoredAccounts: sqlops.Account[]): Thenable<sqlops.Account[]> {
@@ -90,7 +90,7 @@ export class AzureAccountProvider implements sqlops.AccountProvider {
// Attempt to get fresh tokens. If this fails then the account is stale. // Attempt to get fresh tokens. If this fails then the account is stale.
// NOTE: Based on ADAL implementation, getting tokens should use the refresh token if necessary // NOTE: Based on ADAL implementation, getting tokens should use the refresh token if necessary
let task = this.getAccessTokens(account) let task = this.getAccessTokens(account, sqlops.AzureResource.ResourceManagement)
.then( .then(
() => { () => {
return account; return account;
@@ -161,9 +161,14 @@ export class AzureAccountProvider implements sqlops.AccountProvider {
: Promise.reject(localize('accountProviderNotInitialized', 'Account provider not initialized, cannot perform action')); : Promise.reject(localize('accountProviderNotInitialized', 'Account provider not initialized, cannot perform action'));
} }
private getAccessTokens(account: AzureAccount): Thenable<AzureAccountSecurityTokenCollection> { private getAccessTokens(account: AzureAccount, resource: sqlops.AzureResource): Thenable<AzureAccountSecurityTokenCollection> {
let self = this; let self = this;
const resourceIdMap = new Map<sqlops.AzureResource, string>([
[sqlops.AzureResource.ResourceManagement, self._metadata.settings.armResource.id],
[sqlops.AzureResource.Sql, self._metadata.settings.sqlResource.id]
]);
let accessTokenPromises: Thenable<void>[] = []; let accessTokenPromises: Thenable<void>[] = [];
let tokenCollection: AzureAccountSecurityTokenCollection = {}; let tokenCollection: AzureAccountSecurityTokenCollection = {};
for (let tenant of account.properties.tenants) { for (let tenant of account.properties.tenants) {
@@ -172,7 +177,7 @@ export class AzureAccountProvider implements sqlops.AccountProvider {
let context = new adal.AuthenticationContext(authorityUrl, null, self._tokenCache); let context = new adal.AuthenticationContext(authorityUrl, null, self._tokenCache);
context.acquireToken( context.acquireToken(
self._metadata.settings.armResource.id, resourceIdMap.get(resource),
tenant.userId, tenant.userId,
self._metadata.settings.clientId, self._metadata.settings.clientId,
(error: Error, response: adal.TokenResponse | adal.ErrorResponse) => { (error: Error, response: adal.TokenResponse | adal.ErrorResponse) => {

View File

@@ -81,6 +81,11 @@ export interface Settings {
*/ */
armResource?: Resource; armResource?: Resource;
/**
* Information that describes the SQL Azure resource
*/
sqlResource?: Resource;
/** /**
* A list of tenant IDs to authenticate against. If defined, then these IDs will be used * A list of tenant IDs to authenticate against. If defined, then these IDs will be used
* instead of querying the tenants endpoint of the armResource * instead of querying the tenants endpoint of the armResource

View File

@@ -27,6 +27,10 @@ const publicAzureSettings: ProviderSettings = {
id: 'https://management.core.windows.net/', id: 'https://management.core.windows.net/',
endpoint: 'https://management.azure.com' endpoint: 'https://management.azure.com'
}, },
sqlResource: {
id: 'https://database.windows.net/',
endpoint: 'https://database.windows.net'
},
redirectUri: 'http://localhost/redirect' redirectUri: 'http://localhost/redirect'
} }
} }

View File

@@ -212,8 +212,8 @@ export class ApiWrapper {
return sqlops.accounts.getAllAccounts(); return sqlops.accounts.getAllAccounts();
} }
public getSecurityToken(account: sqlops.Account): Thenable<{}> { public getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}> {
return sqlops.accounts.getSecurityToken(account); return sqlops.accounts.getSecurityToken(account, resource);
} }
public readonly onDidChangeAccounts = sqlops.accounts.onDidChangeAccounts; public readonly onDidChangeAccounts = sqlops.accounts.onDidChangeAccounts;

View File

@@ -6,7 +6,7 @@
'use strict'; 'use strict';
import { window, QuickPickItem } from 'vscode'; import { window, QuickPickItem } from 'vscode';
import { IConnectionProfile } from 'sqlops'; import * as sqlops from 'sqlops';
import { generateGuid } from './utils'; import { generateGuid } from './utils';
import { ApiWrapper } from '../apiWrapper'; import { ApiWrapper } from '../apiWrapper';
import { TreeNode } from '../treeNodes'; import { TreeNode } from '../treeNodes';
@@ -30,7 +30,7 @@ export function registerAzureResourceCommands(apiWrapper: ApiWrapper, tree: Azur
let subscriptions = await accountNode.getCachedSubscriptions(); let subscriptions = await accountNode.getCachedSubscriptions();
if (!subscriptions || subscriptions.length === 0) { if (!subscriptions || subscriptions.length === 0) {
const credentials = await servicePool.credentialService.getCredentials(accountNode.account); const credentials = await servicePool.credentialService.getCredentials(accountNode.account, sqlops.AzureResource.ResourceManagement);
subscriptions = await servicePool.subscriptionService.getSubscriptions(accountNode.account, credentials); subscriptions = await servicePool.subscriptionService.getSubscriptions(accountNode.account, credentials);
} }
@@ -71,7 +71,7 @@ export function registerAzureResourceCommands(apiWrapper: ApiWrapper, tree: Azur
}); });
apiWrapper.registerCommand('azureresource.connectsqldb', async (node?: TreeNode) => { apiWrapper.registerCommand('azureresource.connectsqldb', async (node?: TreeNode) => {
let connectionProfile: IConnectionProfile = { let connectionProfile: sqlops.IConnectionProfile = {
id: generateGuid(), id: generateGuid(),
connectionName: undefined, connectionName: undefined,
serverName: undefined, serverName: undefined,

View File

@@ -6,29 +6,29 @@
'use strict'; 'use strict';
import { ServiceClientCredentials } from 'ms-rest'; import { ServiceClientCredentials } from 'ms-rest';
import { Account, DidChangeAccountsParams } from 'sqlops'; import * as sqlops from 'sqlops';
import { Event } from 'vscode'; import { Event } from 'vscode';
import { AzureResourceSubscription, AzureResourceDatabaseServer, AzureResourceDatabase } from './models'; import { AzureResourceSubscription, AzureResourceDatabaseServer, AzureResourceDatabase } from './models';
export interface IAzureResourceAccountService { export interface IAzureResourceAccountService {
getAccounts(): Promise<Account[]>; getAccounts(): Promise<sqlops.Account[]>;
readonly onDidChangeAccounts: Event<DidChangeAccountsParams>; readonly onDidChangeAccounts: Event<sqlops.DidChangeAccountsParams>;
} }
export interface IAzureResourceCredentialService { export interface IAzureResourceCredentialService {
getCredentials(account: Account): Promise<ServiceClientCredentials[]>; getCredentials(account: sqlops.Account, resource: sqlops.AzureResource): Promise<ServiceClientCredentials[]>;
} }
export interface IAzureResourceSubscriptionService { export interface IAzureResourceSubscriptionService {
getSubscriptions(account: Account, credentials: ServiceClientCredentials[]): Promise<AzureResourceSubscription[]>; getSubscriptions(account: sqlops.Account, credentials: ServiceClientCredentials[]): Promise<AzureResourceSubscription[]>;
} }
export interface IAzureResourceSubscriptionFilterService { export interface IAzureResourceSubscriptionFilterService {
getSelectedSubscriptions(account: Account): Promise<AzureResourceSubscription[]>; getSelectedSubscriptions(account: sqlops.Account): Promise<AzureResourceSubscription[]>;
saveSelectedSubscriptions(account: Account, selectedSubscriptions: AzureResourceSubscription[]): Promise<void>; saveSelectedSubscriptions(account: sqlops.Account, selectedSubscriptions: AzureResourceSubscription[]): Promise<void>;
} }
export interface IAzureResourceDatabaseServerService { export interface IAzureResourceDatabaseServerService {

View File

@@ -5,7 +5,7 @@
'use strict'; 'use strict';
import { Account } from 'sqlops'; import * as sqlops from 'sqlops';
import { TokenCredentials, ServiceClientCredentials } from 'ms-rest'; import { TokenCredentials, ServiceClientCredentials } from 'ms-rest';
import { ApiWrapper } from '../../apiWrapper'; import { ApiWrapper } from '../../apiWrapper';
import * as nls from 'vscode-nls'; import * as nls from 'vscode-nls';
@@ -21,10 +21,10 @@ export class AzureResourceCredentialService implements IAzureResourceCredentialS
this._apiWrapper = apiWrapper; this._apiWrapper = apiWrapper;
} }
public async getCredentials(account: Account): Promise<ServiceClientCredentials[]> { public async getCredentials(account: sqlops.Account, resource: sqlops.AzureResource): Promise<ServiceClientCredentials[]> {
try { try {
let credentials: TokenCredentials[] = []; let credentials: TokenCredentials[] = [];
let tokens = await this._apiWrapper.getSecurityToken(account); let tokens = await this._apiWrapper.getSecurityToken(account, resource);
for (let tenant of account.properties.tenants) { for (let tenant of account.properties.tenants) {
let token = tokens[tenant.id].token; let token = tokens[tenant.id].token;

View File

@@ -5,7 +5,7 @@
'use strict'; 'use strict';
import { Account } from 'sqlops'; import * as sqlops from 'sqlops';
import { ServiceClientCredentials } from 'ms-rest'; import { ServiceClientCredentials } from 'ms-rest';
import { TreeNode } from '../../treeNodes'; import { TreeNode } from '../../treeNodes';
@@ -28,7 +28,7 @@ export abstract class AzureResourceTreeNodeBase extends TreeNode {
export abstract class AzureResourceContainerTreeNodeBase extends AzureResourceTreeNodeBase { export abstract class AzureResourceContainerTreeNodeBase extends AzureResourceTreeNodeBase {
public constructor( public constructor(
public readonly account: Account, public readonly account: sqlops.Account,
treeChangeHandler: IAzureResourceTreeChangeHandler, treeChangeHandler: IAzureResourceTreeChangeHandler,
parent: TreeNode parent: TreeNode
) { ) {
@@ -45,7 +45,7 @@ export abstract class AzureResourceContainerTreeNodeBase extends AzureResourceTr
protected async getCredentials(): Promise<ServiceClientCredentials[]> { protected async getCredentials(): Promise<ServiceClientCredentials[]> {
try { try {
return await this.servicePool.credentialService.getCredentials(this.account); return await this.servicePool.credentialService.getCredentials(this.account, sqlops.AzureResource.ResourceManagement);
} catch (error) { } catch (error) {
if (error instanceof AzureResourceCredentialError) { if (error instanceof AzureResourceCredentialError) {
this.servicePool.contextService.showErrorMessage(error.message); this.servicePool.contextService.showErrorMessage(error.message);

View File

@@ -87,7 +87,7 @@ describe('AzureResourceAccountTreeNode.info', function(): void {
mockServicePool.subscriptionService = mockSubscriptionService.object; mockServicePool.subscriptionService = mockSubscriptionService.object;
mockServicePool.subscriptionFilterService = mockSubscriptionFilterService.object; mockServicePool.subscriptionFilterService = mockSubscriptionFilterService.object;
mockCredentialService.setup((o) => o.getCredentials(mockAccount)).returns(() => Promise.resolve(mockCredentials)); mockCredentialService.setup((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement)).returns(() => Promise.resolve(mockCredentials));
mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache);
mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache.subscriptions[mockAccount.key.accountId] = mockSubscriptions); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache.subscriptions[mockAccount.key.accountId] = mockSubscriptions);
}); });
@@ -164,7 +164,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function(): void {
mockServicePool.subscriptionService = mockSubscriptionService.object; mockServicePool.subscriptionService = mockSubscriptionService.object;
mockServicePool.subscriptionFilterService = mockSubscriptionFilterService.object; mockServicePool.subscriptionFilterService = mockSubscriptionFilterService.object;
mockCredentialService.setup((o) => o.getCredentials(mockAccount)).returns(() => Promise.resolve(mockCredentials)); mockCredentialService.setup((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement)).returns(() => Promise.resolve(mockCredentials));
mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache);
mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache.subscriptions[mockAccount.key.accountId] = mockSubscriptions); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache.subscriptions[mockAccount.key.accountId] = mockSubscriptions);
}); });
@@ -177,7 +177,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function(): void {
const children = await accountTreeNode.getChildren(); const children = await accountTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.once()); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.once());
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredentials), TypeMoq.Times.once()); mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredentials), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once());
@@ -213,7 +213,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function(): void {
await accountTreeNode.getChildren(); await accountTreeNode.getChildren();
const children = await accountTreeNode.getChildren(); const children = await accountTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.exactly(1)); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.exactly(1));
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredentials), TypeMoq.Times.exactly(1)); mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredentials), TypeMoq.Times.exactly(1));
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(2)); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(2));
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.exactly(1)); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.exactly(1));
@@ -267,7 +267,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function(): void {
const children = await accountTreeNode.getChildren(); const children = await accountTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.once()); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.once());
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredentials), TypeMoq.Times.once()); mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredentials), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.never());

View File

@@ -118,7 +118,7 @@ describe('AzureResourceDatabaseContainerTreeNode.getChildren', function(): void
mockServicePool.credentialService = mockCredentialService.object; mockServicePool.credentialService = mockCredentialService.object;
mockServicePool.databaseService = mockDatabaseService.object; mockServicePool.databaseService = mockDatabaseService.object;
mockCredentialService.setup((o) => o.getCredentials(mockAccount)).returns(() => Promise.resolve(mockCredentials)); mockCredentialService.setup((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement)).returns(() => Promise.resolve(mockCredentials));
mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockDatabaseContainerCache); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockDatabaseContainerCache);
mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockDatabaseContainerCache.databases[mockSubscription.id] = mockDatabases); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockDatabaseContainerCache.databases[mockSubscription.id] = mockDatabases);
}); });
@@ -130,7 +130,7 @@ describe('AzureResourceDatabaseContainerTreeNode.getChildren', function(): void
const children = await databaseContainerTreeNode.getChildren(); const children = await databaseContainerTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.once()); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.once());
mockDatabaseService.verify((o) => o.getDatabases(mockSubscription, mockCredentials), TypeMoq.Times.once()); mockDatabaseService.verify((o) => o.getDatabases(mockSubscription, mockCredentials), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once());
@@ -160,7 +160,7 @@ describe('AzureResourceDatabaseContainerTreeNode.getChildren', function(): void
await databaseContainerTreeNode.getChildren(); await databaseContainerTreeNode.getChildren();
const children = await databaseContainerTreeNode.getChildren(); const children = await databaseContainerTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.exactly(1)); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.exactly(1));
mockDatabaseService.verify((o) => o.getDatabases(mockSubscription, mockCredentials), TypeMoq.Times.exactly(1)); mockDatabaseService.verify((o) => o.getDatabases(mockSubscription, mockCredentials), TypeMoq.Times.exactly(1));
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(2)); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(2));
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.exactly(1)); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.exactly(1));
@@ -193,7 +193,7 @@ describe('AzureResourceDatabaseContainerTreeNode.getChildren', function(): void
const databaseContainerTreeNode = new AzureResourceDatabaseContainerTreeNode(mockSubscription, mockAccount, mockTreeChangeHandler.object, undefined); const databaseContainerTreeNode = new AzureResourceDatabaseContainerTreeNode(mockSubscription, mockAccount, mockTreeChangeHandler.object, undefined);
const children = await databaseContainerTreeNode.getChildren(); const children = await databaseContainerTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.once()); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.once());
mockDatabaseService.verify((o) => o.getDatabases(mockSubscription, mockCredentials), TypeMoq.Times.once()); mockDatabaseService.verify((o) => o.getDatabases(mockSubscription, mockCredentials), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.never());

View File

@@ -118,7 +118,7 @@ describe('AzureResourceDatabaseServerContainerTreeNode.getChildren', function():
mockServicePool.credentialService = mockCredentialService.object; mockServicePool.credentialService = mockCredentialService.object;
mockServicePool.databaseServerService = mockDatabaseServerService.object; mockServicePool.databaseServerService = mockDatabaseServerService.object;
mockCredentialService.setup((o) => o.getCredentials(mockAccount)).returns(() => Promise.resolve(mockCredentials)); mockCredentialService.setup((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement)).returns(() => Promise.resolve(mockCredentials));
mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockDatabaseServerContainerCache); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockDatabaseServerContainerCache);
mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockDatabaseServerContainerCache.databaseServers[mockSubscription.id] = mockDatabaseServers); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockDatabaseServerContainerCache.databaseServers[mockSubscription.id] = mockDatabaseServers);
}); });
@@ -130,7 +130,7 @@ describe('AzureResourceDatabaseServerContainerTreeNode.getChildren', function():
const children = await databaseServerContainerTreeNode.getChildren(); const children = await databaseServerContainerTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.once()); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.once());
mockDatabaseServerService.verify((o) => o.getDatabaseServers(mockSubscription, mockCredentials), TypeMoq.Times.once()); mockDatabaseServerService.verify((o) => o.getDatabaseServers(mockSubscription, mockCredentials), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once());
@@ -160,7 +160,7 @@ describe('AzureResourceDatabaseServerContainerTreeNode.getChildren', function():
await databaseServerContainerTreeNode.getChildren(); await databaseServerContainerTreeNode.getChildren();
const children = await databaseServerContainerTreeNode.getChildren(); const children = await databaseServerContainerTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.exactly(1)); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.exactly(1));
mockDatabaseServerService.verify((o) => o.getDatabaseServers(mockSubscription, mockCredentials), TypeMoq.Times.exactly(1)); mockDatabaseServerService.verify((o) => o.getDatabaseServers(mockSubscription, mockCredentials), TypeMoq.Times.exactly(1));
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(2)); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(2));
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.exactly(1)); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.exactly(1));
@@ -193,7 +193,7 @@ describe('AzureResourceDatabaseServerContainerTreeNode.getChildren', function():
const databaseServerContainerTreeNode = new AzureResourceDatabaseServerContainerTreeNode(mockSubscription, mockAccount, mockTreeChangeHandler.object, undefined); const databaseServerContainerTreeNode = new AzureResourceDatabaseServerContainerTreeNode(mockSubscription, mockAccount, mockTreeChangeHandler.object, undefined);
const children = await databaseServerContainerTreeNode.getChildren(); const children = await databaseServerContainerTreeNode.getChildren();
mockCredentialService.verify((o) => o.getCredentials(mockAccount), TypeMoq.Times.once()); mockCredentialService.verify((o) => o.getCredentials(mockAccount, sqlops.AzureResource.ResourceManagement), TypeMoq.Times.once());
mockDatabaseServerService.verify((o) => o.getDatabaseServers(mockSubscription, mockCredentials), TypeMoq.Times.once()); mockDatabaseServerService.verify((o) => o.getDatabaseServers(mockSubscription, mockCredentials), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.never());

View File

@@ -285,6 +285,10 @@
{ {
"displayName": "Windows Authentication", "displayName": "Windows Authentication",
"name": "Integrated" "name": "Integrated"
},
{
"displayName": "Azure Active Directory - Universal with MFA support",
"name": "AzureMFA"
} }
], ],
"isRequired": true, "isRequired": true,

View File

@@ -12,9 +12,10 @@ import * as types from 'vs/base/common/types';
import * as sqlops from 'sqlops'; import * as sqlops from 'sqlops';
export function appendRow(container: Builder, label: string, labelClass: string, cellContainerClass: string): Builder { export function appendRow(container: Builder, label: string, labelClass: string, cellContainerClass: string, rowContainerClass?: string): Builder {
let cellContainer: Builder; let cellContainer: Builder;
container.element('tr', {}, (rowContainer) => { let rowAttributes = rowContainerClass ? { class: rowContainerClass } : {};
container.element('tr', rowAttributes, (rowContainer) => {
rowContainer.element('td', { class: labelClass }, (labelCellContainer) => { rowContainer.element('td', { class: labelClass }, (labelCellContainer) => {
labelCellContainer.div({}, (labelContainer) => { labelCellContainer.div({}, (labelContainer) => {
labelContainer.text(label); labelContainer.text(label);

View File

@@ -75,7 +75,12 @@ export class SelectBox extends vsSelectBox {
// explicitly set the accessible role so that the screen readers can read the control type properly // explicitly set the accessible role so that the screen readers can read the control type properly
this.selectElement.setAttribute('role', 'combobox'); this.selectElement.setAttribute('role', 'combobox');
this._selectBoxOptions = selectBoxOptions; this._selectBoxOptions = selectBoxOptions;
var focusTracker = dom.trackFocus(this.selectElement);
this._register(focusTracker);
this._register(focusTracker.onDidBlur(() => this._hideMessage()));
this._register(focusTracker.onDidFocus(() => this._showMessage()));
} }
public style(styles: ISelectBoxStyles): void { public style(styles: ISelectBoxStyles): void {
@@ -142,6 +147,10 @@ export class SelectBox extends vsSelectBox {
this.applyStyles(); this.applyStyles();
} }
public hasFocus(): boolean {
return document.activeElement === this.selectElement;
}
public showMessage(message: IMessage): void { public showMessage(message: IMessage): void {
this.message = message; this.message = message;
@@ -163,7 +172,9 @@ export class SelectBox extends vsSelectBox {
aria.alert(alertText); aria.alert(alertText);
this._showMessage(); if (this.hasFocus()) {
this._showMessage();
}
} }
public _showMessage(): void { public _showMessage(): void {

View File

@@ -12,7 +12,7 @@ import * as sqlops from 'sqlops';
import { IConnectionProfile } from 'sql/parts/connection/common/interfaces'; import { IConnectionProfile } from 'sql/parts/connection/common/interfaces';
import { IErrorMessageService } from 'sql/parts/connection/common/connectionManagement'; import { IErrorMessageService } from 'sql/parts/connection/common/connectionManagement';
import { FirewallRuleDialog } from 'sql/parts/accountManagement/firewallRuleDialog/firewallRuleDialog'; import { FirewallRuleDialog } from 'sql/parts/accountManagement/firewallRuleDialog/firewallRuleDialog';
import { IAccountManagementService } from 'sql/services/accountManagement/interfaces'; import { IAccountManagementService, AzureResource } from 'sql/services/accountManagement/interfaces';
import { IResourceProviderService } from 'sql/parts/accountManagement/common/interfaces'; import { IResourceProviderService } from 'sql/parts/accountManagement/common/interfaces';
import { Deferred } from 'sql/base/common/promise'; import { Deferred } from 'sql/base/common/promise';
@@ -61,7 +61,7 @@ export class FirewallRuleDialogController {
private handleOnCreateFirewallRule(): void { private handleOnCreateFirewallRule(): void {
let resourceProviderId = this._resourceProviderId; let resourceProviderId = this._resourceProviderId;
this._accountManagementService.getSecurityToken(this._firewallRuleDialog.viewModel.selectedAccount).then(tokenMappings => { this._accountManagementService.getSecurityToken(this._firewallRuleDialog.viewModel.selectedAccount, AzureResource.ResourceManagement).then(tokenMappings => {
let firewallRuleInfo: sqlops.FirewallRuleInfo = { let firewallRuleInfo: sqlops.FirewallRuleInfo = {
startIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.fromSubnetIPRange, startIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.fromSubnetIPRange,
endIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.toSubnetIPRange, endIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.toSubnetIPRange,

View File

@@ -34,12 +34,13 @@ import { Deferred } from 'sql/base/common/promise';
import { ConnectionOptionSpecialType } from 'sql/workbench/api/common/sqlExtHostTypes'; import { ConnectionOptionSpecialType } from 'sql/workbench/api/common/sqlExtHostTypes';
import { values } from 'sql/base/common/objects'; import { values } from 'sql/base/common/objects';
import { ConnectionProviderProperties, IConnectionProviderRegistry, Extensions as ConnectionProviderExtensions } from 'sql/workbench/parts/connection/common/connectionProviderExtension'; import { ConnectionProviderProperties, IConnectionProviderRegistry, Extensions as ConnectionProviderExtensions } from 'sql/workbench/parts/connection/common/connectionProviderExtension';
import { IAccountManagementService, AzureResource } from 'sql/services/accountManagement/interfaces';
import * as sqlops from 'sqlops'; import * as sqlops from 'sqlops';
import * as nls from 'vs/nls'; import * as nls from 'vs/nls';
import * as errors from 'vs/base/common/errors'; import * as errors from 'vs/base/common/errors';
import { IDisposable, dispose, Disposable } from 'vs/base/common/lifecycle'; import { Disposable } from 'vs/base/common/lifecycle';
import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation'; import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation';
import { IEditorService, ACTIVE_GROUP } from 'vs/workbench/services/editor/common/editorService'; import { IEditorService, ACTIVE_GROUP } from 'vs/workbench/services/editor/common/editorService';
import * as platform from 'vs/platform/registry/common/platform'; import * as platform from 'vs/platform/registry/common/platform';
@@ -58,7 +59,6 @@ import * as statusbar from 'vs/workbench/browser/parts/statusbar/statusbar';
import { IViewletService } from 'vs/workbench/services/viewlet/browser/viewlet'; import { IViewletService } from 'vs/workbench/services/viewlet/browser/viewlet';
import { IStatusbarService } from 'vs/platform/statusbar/common/statusbar'; import { IStatusbarService } from 'vs/platform/statusbar/common/statusbar';
import { ICommandService } from 'vs/platform/commands/common/commands'; import { ICommandService } from 'vs/platform/commands/common/commands';
import { EditorGroup } from 'vs/workbench/common/editor/editorGroup';
export class ConnectionManagementService extends Disposable implements IConnectionManagementService { export class ConnectionManagementService extends Disposable implements IConnectionManagementService {
@@ -100,7 +100,8 @@ export class ConnectionManagementService extends Disposable implements IConnecti
@IStatusbarService private _statusBarService: IStatusbarService, @IStatusbarService private _statusBarService: IStatusbarService,
@IResourceProviderService private _resourceProviderService: IResourceProviderService, @IResourceProviderService private _resourceProviderService: IResourceProviderService,
@IViewletService private _viewletService: IViewletService, @IViewletService private _viewletService: IViewletService,
@IAngularEventingService private _angularEventing: IAngularEventingService @IAngularEventingService private _angularEventing: IAngularEventingService,
@IAccountManagementService private _accountManagementService: IAccountManagementService
) { ) {
super(); super();
if (this._instantiationService) { if (this._instantiationService) {
@@ -248,7 +249,8 @@ export class ConnectionManagementService extends Disposable implements IConnecti
* Load the password for the profile * Load the password for the profile
* @param connectionProfile Connection Profile * @param connectionProfile Connection Profile
*/ */
public addSavedPassword(connectionProfile: IConnectionProfile): Promise<IConnectionProfile> { public async addSavedPassword(connectionProfile: IConnectionProfile): Promise<IConnectionProfile> {
await this.fillInAzureTokenIfNeeded(connectionProfile);
return this._connectionStore.addSavedPassword(connectionProfile).then(result => result.profile); return this._connectionStore.addSavedPassword(connectionProfile).then(result => result.profile);
} }
@@ -274,7 +276,7 @@ export class ConnectionManagementService extends Disposable implements IConnecti
let self = this; let self = this;
return new Promise<IConnectionResult>((resolve, reject) => { return new Promise<IConnectionResult>((resolve, reject) => {
// Load the password if it's not already loaded // Load the password if it's not already loaded
self._connectionStore.addSavedPassword(connection).then(result => { self._connectionStore.addSavedPassword(connection).then(async result => {
let newConnection = result.profile; let newConnection = result.profile;
let foundPassword = result.savedCred; let foundPassword = result.savedCred;
@@ -286,8 +288,12 @@ export class ConnectionManagementService extends Disposable implements IConnecti
foundPassword = true; foundPassword = true;
} }
} }
// Fill in the Azure account token if needed and open the connection dialog if it fails
let tokenFillSuccess = await self.fillInAzureTokenIfNeeded(newConnection);
// If the password is required and still not loaded show the dialog // If the password is required and still not loaded show the dialog
if (!foundPassword && self._connectionStore.isPasswordRequired(newConnection) && !newConnection.password) { if ((!foundPassword && self._connectionStore.isPasswordRequired(newConnection) && !newConnection.password) || !tokenFillSuccess) {
resolve(self.showConnectionDialogOnError(connection, owner, { connected: false, errorMessage: undefined, callStack: undefined, errorCode: undefined }, options)); resolve(self.showConnectionDialogOnError(connection, owner, { connected: false, errorMessage: undefined, callStack: undefined, errorCode: undefined }, options));
} else { } else {
// Try to connect // Try to connect
@@ -449,10 +455,14 @@ export class ConnectionManagementService extends Disposable implements IConnecti
showFirewallRuleOnError: true showFirewallRuleOnError: true
}; };
} }
return new Promise<IConnectionResult>((resolve, reject) => { return new Promise<IConnectionResult>(async (resolve, reject) => {
if (callbacks.onConnectStart) { if (callbacks.onConnectStart) {
callbacks.onConnectStart(); callbacks.onConnectStart();
} }
let tokenFillSuccess = await this.fillInAzureTokenIfNeeded(connection);
if (!tokenFillSuccess) {
throw new Error(nls.localize('connection.noAzureAccount', 'Failed to get Azure account token for connection'));
}
this.createNewConnection(uri, connection).then(connectionResult => { this.createNewConnection(uri, connection).then(connectionResult => {
if (connectionResult && connectionResult.connected) { if (connectionResult && connectionResult.connected) {
if (callbacks.onConnectSuccess) { if (callbacks.onConnectSuccess) {
@@ -743,8 +753,33 @@ export class ConnectionManagementService extends Disposable implements IConnecti
} }
} }
private async fillInAzureTokenIfNeeded(connection: IConnectionProfile): Promise<boolean> {
if (connection.authenticationType !== Constants.azureMFA || connection.options['azureAccountToken']) {
return true;
}
let accounts = await this._accountManagementService.getAccountsForProvider('azurePublicCloud');
if (accounts && accounts.length > 0) {
let account = accounts.find(account => account.key.accountId === connection.userName);
if (account) {
if (account.isStale) {
try {
account = await this._accountManagementService.refreshAccount(account);
} catch {
// refreshAccount throws an error if the user cancels the dialog
return false;
}
}
let tokens = await this._accountManagementService.getSecurityToken(account, AzureResource.Sql);
connection.options['azureAccountToken'] = Object.values(tokens)[0].token;
connection.options['password'] = '';
return true;
}
}
return false;
}
// Request Senders // Request Senders
private sendConnectRequest(connection: IConnectionProfile, uri: string): Thenable<boolean> { private async sendConnectRequest(connection: IConnectionProfile, uri: string): Promise<boolean> {
let connectionInfo = Object.assign({}, { let connectionInfo = Object.assign({}, {
options: connection.options options: connection.options
}); });

View File

@@ -33,4 +33,5 @@ export const passwordChars = '***************';
/* authentication types */ /* authentication types */
export const sqlLogin = 'SqlLogin'; export const sqlLogin = 'SqlLogin';
export const integrated = 'Integrated'; export const integrated = 'Integrated';
export const azureMFA = 'AzureMFA';

View File

@@ -22,6 +22,8 @@ import { Dropdown } from 'sql/base/browser/ui/editableDropdown/dropdown';
import { IConnectionManagementService } from 'sql/parts/connection/common/connectionManagement'; import { IConnectionManagementService } from 'sql/parts/connection/common/connectionManagement';
import { ICapabilitiesService } from 'sql/services/capabilities/capabilitiesService'; import { ICapabilitiesService } from 'sql/services/capabilities/capabilitiesService';
import { ConnectionProfile } from '../common/connectionProfile'; import { ConnectionProfile } from '../common/connectionProfile';
import * as styler from 'sql/common/theme/styler';
import { IAccountManagementService } from 'sql/services/accountManagement/interfaces';
import * as sqlops from 'sqlops'; import * as sqlops from 'sqlops';
@@ -30,7 +32,6 @@ import { IContextViewService } from 'vs/platform/contextview/browser/contextView
import { localize } from 'vs/nls'; import { localize } from 'vs/nls';
import * as DOM from 'vs/base/browser/dom'; import * as DOM from 'vs/base/browser/dom';
import { IThemeService } from 'vs/platform/theme/common/themeService'; import { IThemeService } from 'vs/platform/theme/common/themeService';
import * as styler from 'vs/platform/theme/common/styler';
import { OS, OperatingSystem } from 'vs/base/common/platform'; import { OS, OperatingSystem } from 'vs/base/common/platform';
import { Builder, $ } from 'vs/base/browser/builder'; import { Builder, $ } from 'vs/base/browser/builder';
import { MessageType } from 'vs/base/browser/ui/inputbox/inputBox'; import { MessageType } from 'vs/base/browser/ui/inputbox/inputBox';
@@ -50,6 +51,11 @@ export class ConnectionWidget {
private _passwordInputBox: InputBox; private _passwordInputBox: InputBox;
private _password: string; private _password: string;
private _rememberPasswordCheckBox: Checkbox; private _rememberPasswordCheckBox: Checkbox;
private _azureAccountDropdown: SelectBox;
private _refreshCredentialsLinkBuilder: Builder;
private _addAzureAccountMessage: string = localize('connectionWidget.AddAzureAccount', 'Add an account...');
private readonly _azureProviderId = 'azurePublicCloud';
private _azureAccountList: sqlops.Account[];
private _advancedButton: Button; private _advancedButton: Button;
private _callbacks: IConnectionComponentCallbacks; private _callbacks: IConnectionComponentCallbacks;
private _authTypeSelectBox: SelectBox; private _authTypeSelectBox: SelectBox;
@@ -59,7 +65,7 @@ export class ConnectionWidget {
private _focusedBeforeHandleOnConnection: HTMLElement; private _focusedBeforeHandleOnConnection: HTMLElement;
private _providerName: string; private _providerName: string;
private _authTypeMap: { [providerName: string]: AuthenticationType[] } = { private _authTypeMap: { [providerName: string]: AuthenticationType[] } = {
[Constants.mssqlProviderName]: [new AuthenticationType(Constants.integrated, false), new AuthenticationType(Constants.sqlLogin, true)] [Constants.mssqlProviderName]: [AuthenticationType.SqlLogin, AuthenticationType.Integrated, AuthenticationType.AzureMFA]
}; };
private _saveProfile: boolean; private _saveProfile: boolean;
private _databaseDropdownExpanded: boolean = false; private _databaseDropdownExpanded: boolean = false;
@@ -96,7 +102,8 @@ export class ConnectionWidget {
@IConnectionManagementService private _connectionManagementService: IConnectionManagementService, @IConnectionManagementService private _connectionManagementService: IConnectionManagementService,
@ICapabilitiesService private _capabilitiesService: ICapabilitiesService, @ICapabilitiesService private _capabilitiesService: ICapabilitiesService,
@IClipboardService private _clipboardService: IClipboardService, @IClipboardService private _clipboardService: IClipboardService,
@IConfigurationService private _configurationService: IConfigurationService @IConfigurationService private _configurationService: IConfigurationService,
@IAccountManagementService private _accountManagementService: IAccountManagementService
) { ) {
this._callbacks = callbacks; this._callbacks = callbacks;
this._toDispose = []; this._toDispose = [];
@@ -109,9 +116,9 @@ export class ConnectionWidget {
var authTypeOption = this._optionsMaps[ConnectionOptionSpecialType.authType]; var authTypeOption = this._optionsMaps[ConnectionOptionSpecialType.authType];
if (authTypeOption) { if (authTypeOption) {
if (OS === OperatingSystem.Windows) { if (OS === OperatingSystem.Windows) {
authTypeOption.defaultValue = this.getAuthTypeDisplayName(Constants.integrated); authTypeOption.defaultValue = this.getAuthTypeDisplayName(AuthenticationType.Integrated);
} else { } else {
authTypeOption.defaultValue = this.getAuthTypeDisplayName(Constants.sqlLogin); authTypeOption.defaultValue = this.getAuthTypeDisplayName(AuthenticationType.SqlLogin);
} }
this._authTypeSelectBox = new SelectBox(authTypeOption.categoryValues.map(c => c.displayName), authTypeOption.defaultValue, this._contextViewService, undefined, { ariaLabel: authTypeOption.displayName }); this._authTypeSelectBox = new SelectBox(authTypeOption.categoryValues.map(c => c.displayName), authTypeOption.defaultValue, this._contextViewService, undefined, { ariaLabel: authTypeOption.displayName });
} }
@@ -182,7 +189,7 @@ export class ConnectionWidget {
// Username // Username
let self = this; let self = this;
let userNameOption = this._optionsMaps[ConnectionOptionSpecialType.userName]; let userNameOption = this._optionsMaps[ConnectionOptionSpecialType.userName];
let userNameBuilder = DialogHelper.appendRow(this._tableContainer, userNameOption.displayName, 'connection-label', 'connection-input'); let userNameBuilder = DialogHelper.appendRow(this._tableContainer, userNameOption.displayName, 'connection-label', 'connection-input', 'username-password-row');
this._userNameInputBox = new InputBox(userNameBuilder.getHTMLElement(), this._contextViewService, { this._userNameInputBox = new InputBox(userNameBuilder.getHTMLElement(), this._contextViewService, {
validationOptions: { validationOptions: {
validation: (value: string) => self.validateUsername(value, userNameOption.isRequired) ? ({ type: MessageType.ERROR, content: localize('connectionWidget.missingRequireField', '{0} is required.', userNameOption.displayName) }) : null validation: (value: string) => self.validateUsername(value, userNameOption.isRequired) ? ({ type: MessageType.ERROR, content: localize('connectionWidget.missingRequireField', '{0} is required.', userNameOption.displayName) }) : null
@@ -191,14 +198,22 @@ export class ConnectionWidget {
}); });
// Password // Password
let passwordOption = this._optionsMaps[ConnectionOptionSpecialType.password]; let passwordOption = this._optionsMaps[ConnectionOptionSpecialType.password];
let passwordBuilder = DialogHelper.appendRow(this._tableContainer, passwordOption.displayName, 'connection-label', 'connection-input'); let passwordBuilder = DialogHelper.appendRow(this._tableContainer, passwordOption.displayName, 'connection-label', 'connection-input', 'username-password-row');
this._passwordInputBox = new InputBox(passwordBuilder.getHTMLElement(), this._contextViewService, { ariaLabel: passwordOption.displayName }); this._passwordInputBox = new InputBox(passwordBuilder.getHTMLElement(), this._contextViewService, { ariaLabel: passwordOption.displayName });
this._passwordInputBox.inputElement.type = 'password'; this._passwordInputBox.inputElement.type = 'password';
this._password = ''; this._password = '';
// Remember password // Remember password
let rememberPasswordLabel = localize('rememberPassword', 'Remember password'); let rememberPasswordLabel = localize('rememberPassword', 'Remember password');
this._rememberPasswordCheckBox = this.appendCheckbox(this._tableContainer, rememberPasswordLabel, 'connection-checkbox', 'connection-input', false); this._rememberPasswordCheckBox = this.appendCheckbox(this._tableContainer, rememberPasswordLabel, 'connection-checkbox', 'connection-input', 'username-password-row', false);
// Azure account picker
let accountLabel = localize('connection.azureAccountDropdownLabel', 'Account');
let accountDropdownBuilder = DialogHelper.appendRow(this._tableContainer, accountLabel, 'connection-label', 'connection-input', 'azure-account-row');
this._azureAccountDropdown = new SelectBox([], undefined, this._contextViewService, accountDropdownBuilder.getContainer(), { ariaLabel: accountLabel });
DialogHelper.appendInputSelectBox(accountDropdownBuilder, this._azureAccountDropdown);
let refreshCredentialsBuilder = DialogHelper.appendRow(this._tableContainer, '', 'connection-label', 'connection-input', 'azure-account-row refresh-credentials-link');
this._refreshCredentialsLinkBuilder = refreshCredentialsBuilder.a({ href: '#' }).text(localize('connectionWidget.refreshAzureCredentials', 'Refresh account credentials'));
// Database // Database
let databaseOption = this._optionsMaps[ConnectionOptionSpecialType.databaseName]; let databaseOption = this._optionsMaps[ConnectionOptionSpecialType.databaseName];
@@ -228,7 +243,7 @@ export class ConnectionWidget {
private validateUsername(value: string, isOptionRequired: boolean): boolean { private validateUsername(value: string, isOptionRequired: boolean): boolean {
let currentAuthType = this._authTypeSelectBox ? this.getMatchingAuthType(this._authTypeSelectBox.value) : undefined; let currentAuthType = this._authTypeSelectBox ? this.getMatchingAuthType(this._authTypeSelectBox.value) : undefined;
if (!currentAuthType || currentAuthType.showUsernameAndPassword) { if (!currentAuthType || currentAuthType === AuthenticationType.SqlLogin) {
if (!value && isOptionRequired) { if (!value && isOptionRequired) {
return true; return true;
} }
@@ -254,9 +269,9 @@ export class ConnectionWidget {
return button; return button;
} }
private appendCheckbox(container: Builder, label: string, checkboxClass: string, cellContainerClass: string, isChecked: boolean): Checkbox { private appendCheckbox(container: Builder, label: string, checkboxClass: string, cellContainerClass: string, rowContainerClass: string, isChecked: boolean): Checkbox {
let checkbox: Checkbox; let checkbox: Checkbox;
container.element('tr', {}, (rowContainer) => { container.element('tr', { class: rowContainerClass }, (rowContainer) => {
rowContainer.element('td'); rowContainer.element('td');
rowContainer.element('td', { class: cellContainerClass }, (inputCellContainer) => { rowContainer.element('td', { class: cellContainerClass }, (inputCellContainer) => {
checkbox = new Checkbox(inputCellContainer.getHTMLElement(), { label, checked: isChecked, ariaLabel: label }); checkbox = new Checkbox(inputCellContainer.getHTMLElement(), { label, checked: isChecked, ariaLabel: label });
@@ -275,6 +290,7 @@ export class ConnectionWidget {
this._toDispose.push(styler.attachSelectBoxStyler(this._serverGroupSelectBox, this._themeService)); this._toDispose.push(styler.attachSelectBoxStyler(this._serverGroupSelectBox, this._themeService));
this._toDispose.push(attachButtonStyler(this._advancedButton, this._themeService)); this._toDispose.push(attachButtonStyler(this._advancedButton, this._themeService));
this._toDispose.push(attachCheckboxStyler(this._rememberPasswordCheckBox, this._themeService)); this._toDispose.push(attachCheckboxStyler(this._rememberPasswordCheckBox, this._themeService));
this._toDispose.push(styler.attachSelectBoxStyler(this._azureAccountDropdown, this._themeService));
if (this._authTypeSelectBox) { if (this._authTypeSelectBox) {
// Theme styler // Theme styler
@@ -285,6 +301,23 @@ export class ConnectionWidget {
})); }));
} }
if (this._azureAccountDropdown) {
this._toDispose.push(styler.attachSelectBoxStyler(this._azureAccountDropdown, this._themeService));
this._toDispose.push(this._azureAccountDropdown.onDidSelect(() => {
this.onAzureAccountSelected();
}));
}
if (this._refreshCredentialsLinkBuilder) {
this._toDispose.push(this._refreshCredentialsLinkBuilder.on(DOM.EventType.CLICK, async () => {
let account = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value);
if (account) {
await this._accountManagementService.refreshAccount(account);
this.fillInAzureAccountOptions();
}
}));
}
this._toDispose.push(this._serverGroupSelectBox.onDidSelect(selectedGroup => { this._toDispose.push(this._serverGroupSelectBox.onDidSelect(selectedGroup => {
this.onGroupSelected(selectedGroup.selected); this.onGroupSelected(selectedGroup.selected);
})); }));
@@ -342,7 +375,7 @@ export class ConnectionWidget {
private setConnectButton(): void { private setConnectButton(): void {
let showUsernameAndPassword: boolean = true; let showUsernameAndPassword: boolean = true;
if (this.authType) { if (this.authType) {
showUsernameAndPassword = this.authType.showUsernameAndPassword; showUsernameAndPassword = this.authType === AuthenticationType.SqlLogin;
} }
showUsernameAndPassword ? this._callbacks.onSetConnectButton(!!this.serverName && !!this.userName) : showUsernameAndPassword ? this._callbacks.onSetConnectButton(!!this.serverName && !!this.userName) :
this._callbacks.onSetConnectButton(!!this.serverName); this._callbacks.onSetConnectButton(!!this.serverName);
@@ -350,7 +383,7 @@ export class ConnectionWidget {
private onAuthTypeSelected(selectedAuthType: string) { private onAuthTypeSelected(selectedAuthType: string) {
let currentAuthType = this.getMatchingAuthType(selectedAuthType); let currentAuthType = this.getMatchingAuthType(selectedAuthType);
if (!currentAuthType.showUsernameAndPassword) { if (currentAuthType !== AuthenticationType.SqlLogin) {
this._userNameInputBox.disable(); this._userNameInputBox.disable();
this._passwordInputBox.disable(); this._passwordInputBox.disable();
this._userNameInputBox.hideMessage(); this._userNameInputBox.hideMessage();
@@ -366,6 +399,68 @@ export class ConnectionWidget {
this._passwordInputBox.enable(); this._passwordInputBox.enable();
this._rememberPasswordCheckBox.enabled = true; this._rememberPasswordCheckBox.enabled = true;
} }
if (currentAuthType === AuthenticationType.AzureMFA) {
this.fillInAzureAccountOptions();
this._azureAccountDropdown.enable();
let tableContainer = this._tableContainer.getContainer();
tableContainer.classList.add('hide-username-password');
tableContainer.classList.remove('hide-azure-accounts');
} else {
this._azureAccountDropdown.disable();
let tableContainer = this._tableContainer.getContainer();
tableContainer.classList.remove('hide-username-password');
tableContainer.classList.add('hide-azure-accounts');
this._azureAccountDropdown.hideMessage();
}
}
private async fillInAzureAccountOptions(): Promise<void> {
let oldSelection = this._azureAccountDropdown.value;
this._azureAccountList = await this._accountManagementService.getAccountsForProvider(this._azureProviderId);
let accountDropdownOptions = this._azureAccountList.map(account => account.key.accountId);
if (accountDropdownOptions.length === 0) {
// If there are no accounts add a blank option so that add account isn't automatically selected
accountDropdownOptions.unshift('');
}
accountDropdownOptions.push(this._addAzureAccountMessage);
this._azureAccountDropdown.setOptions(accountDropdownOptions);
this._azureAccountDropdown.selectWithOptionName(oldSelection);
this.updateRefreshCredentialsLink();
}
private async updateRefreshCredentialsLink(): Promise<void> {
let chosenAccount = this._azureAccountList.find(account => account.key.accountId === this._azureAccountDropdown.value);
if (chosenAccount && chosenAccount.isStale) {
this._tableContainer.getContainer().classList.remove('hide-refresh-link');
} else {
this._tableContainer.getContainer().classList.add('hide-refresh-link');
}
}
private async onAzureAccountSelected(): Promise<void> {
// Reset the dropdown's validation message if the old selection was not valid but the new one is
this.validateAzureAccountSelection(false);
this._refreshCredentialsLinkBuilder.display('none');
// Open the add account dialog if needed, then select the added account
if (this._azureAccountDropdown.value === this._addAzureAccountMessage) {
let oldAccountIds = this._azureAccountList.map(account => account.key.accountId);
await this._accountManagementService.addAccount(this._azureProviderId);
// Refresh the dropdown's list to include the added account
await this.fillInAzureAccountOptions();
// If a new account was added find it and select it, otherwise select the first account
let newAccount = this._azureAccountList.find(option => !oldAccountIds.some(oldId => oldId === option.key.accountId));
if (newAccount) {
this._azureAccountDropdown.selectWithOptionName(newAccount.key.accountId);
} else {
this._azureAccountDropdown.select(0);
}
}
this.updateRefreshCredentialsLink();
} }
private serverNameChanged(serverName: string) { private serverNameChanged(serverName: string) {
@@ -407,6 +502,7 @@ export class ConnectionWidget {
private clearValidationMessages(): void { private clearValidationMessages(): void {
this._serverNameInputBox.hideMessage(); this._serverNameInputBox.hideMessage();
this._userNameInputBox.hideMessage(); this._userNameInputBox.hideMessage();
this._azureAccountDropdown.hideMessage();
} }
private getModelValue(value: string): string { private getModelValue(value: string): string {
@@ -449,8 +545,8 @@ export class ConnectionWidget {
if (this._authTypeSelectBox) { if (this._authTypeSelectBox) {
this.onAuthTypeSelected(this._authTypeSelectBox.value); this.onAuthTypeSelected(this._authTypeSelectBox.value);
} }
// Disable connect button if - // Disable connect button if -
// 1. Authentication type is SQL Login and no username is provided // 1. Authentication type is SQL Login and no username is provided
// 2. No server name is provided // 2. No server name is provided
@@ -513,7 +609,7 @@ export class ConnectionWidget {
currentAuthType = this.getMatchingAuthType(this._authTypeSelectBox.value); currentAuthType = this.getMatchingAuthType(this._authTypeSelectBox.value);
} }
if (!currentAuthType || currentAuthType.showUsernameAndPassword) { if (!currentAuthType || currentAuthType === AuthenticationType.SqlLogin) {
this._userNameInputBox.enable(); this._userNameInputBox.enable();
this._passwordInputBox.enable(); this._passwordInputBox.enable();
this._rememberPasswordCheckBox.enabled = true; this._rememberPasswordCheckBox.enabled = true;
@@ -537,7 +633,7 @@ export class ConnectionWidget {
} }
public get userName(): string { public get userName(): string {
return this._userNameInputBox.value; return this.authenticationType === AuthenticationType.AzureMFA ? this._azureAccountDropdown.value : this._userNameInputBox.value;
} }
public get password(): string { public get password(): string {
@@ -548,6 +644,27 @@ export class ConnectionWidget {
return this._authTypeSelectBox ? this.getAuthTypeName(this._authTypeSelectBox.value) : undefined; return this._authTypeSelectBox ? this.getAuthTypeName(this._authTypeSelectBox.value) : undefined;
} }
private validateAzureAccountSelection(showMessage: boolean = true): boolean {
if (this.authType !== AuthenticationType.AzureMFA) {
return true;
}
let selected = this._azureAccountDropdown.value;
if (selected === '' || selected === this._addAzureAccountMessage) {
if (showMessage) {
this._azureAccountDropdown.showMessage({
content: localize('connectionWidget.invalidAzureAccount', 'You must select an account'),
type: MessageType.ERROR
});
}
return false;
} else {
this._azureAccountDropdown.hideMessage();
}
return true;
}
private validateInputs(): boolean { private validateInputs(): boolean {
let isFocused = false; let isFocused = false;
let validateServerName = this._serverNameInputBox.validate(); let validateServerName = this._serverNameInputBox.validate();
@@ -565,7 +682,12 @@ export class ConnectionWidget {
this._passwordInputBox.focus(); this._passwordInputBox.focus();
isFocused = true; isFocused = true;
} }
return validateServerName && validateUserName && validatePassword; let validateAzureAccount = this.validateAzureAccountSelection();
if (!validateAzureAccount && !isFocused) {
this._azureAccountDropdown.focus();
isFocused = true;
}
return validateServerName && validateUserName && validatePassword && validateAzureAccount;
} }
public connect(model: IConnectionProfile): boolean { public connect(model: IConnectionProfile): boolean {
@@ -613,7 +735,7 @@ export class ConnectionWidget {
private getMatchingAuthType(displayName: string): AuthenticationType { private getMatchingAuthType(displayName: string): AuthenticationType {
const authType = this._authTypeMap[this._providerName]; const authType = this._authTypeMap[this._providerName];
return authType ? authType.find(authType => this.getAuthTypeDisplayName(authType.name) === displayName) : undefined; return authType ? authType.find(authType => this.getAuthTypeDisplayName(authType) === displayName) : undefined;
} }
public closeDatabaseDropdown(): void { public closeDatabaseDropdown(): void {
@@ -634,18 +756,14 @@ export class ConnectionWidget {
} }
private focusPasswordIfNeeded(): void { private focusPasswordIfNeeded(): void {
if (this.authType && this.authType.showUsernameAndPassword && this.userName && !this.password) { if (this.authType && this.authType === AuthenticationType.SqlLogin && this.userName && !this.password) {
this._passwordInputBox.focus(); this._passwordInputBox.focus();
} }
} }
} }
class AuthenticationType { enum AuthenticationType {
public name: string; SqlLogin = 'SqlLogin',
public showUsernameAndPassword: boolean; Integrated = 'Integrated',
AzureMFA = 'AzureMFA'
constructor(name: string, showUsernameAndPassword: boolean) {
this.name = name;
this.showUsernameAndPassword = showUsernameAndPassword;
}
} }

View File

@@ -28,11 +28,12 @@
overflow: hidden; overflow: hidden;
margin: 0px 11px; margin: 0px 11px;
} }
.connection-dialog .tabBody { .connection-dialog .tabBody {
overflow: hidden; overflow: hidden;
flex: 1 1; flex: 1 1;
display: flex; display: flex;
flex-direction: column; flex-direction: column;
} }
.connection-recent, .connection-saved { .connection-recent, .connection-saved {
@@ -115,3 +116,15 @@
padding: 5px 15px; padding: 5px 15px;
font-weight: 600; font-weight: 600;
} }
.hide-azure-accounts .azure-account-row {
display: none;
}
.hide-username-password .username-password-row {
display: none;
}
.hide-refresh-link .azure-account-row.refresh-credentials-link {
display: none;
}

View File

@@ -261,7 +261,7 @@ export class ObjectExplorerService implements IObjectExplorerService {
return this._activeObjectExplorerNodes[connection.id]; return this._activeObjectExplorerNodes[connection.id];
} }
public createNewSession(providerId: string, connection: ConnectionProfile): Thenable<sqlops.ObjectExplorerSessionResponse> { public async createNewSession(providerId: string, connection: ConnectionProfile): Promise<sqlops.ObjectExplorerSessionResponse> {
let self = this; let self = this;
return new Promise<sqlops.ObjectExplorerSessionResponse>((resolve, reject) => { return new Promise<sqlops.ObjectExplorerSessionResponse>((resolve, reject) => {
let provider = this._providers[providerId]; let provider = this._providers[providerId];

View File

@@ -20,7 +20,7 @@ import { AccountDialogController } from 'sql/parts/accountManagement/accountDial
import { AutoOAuthDialogController } from 'sql/parts/accountManagement/autoOAuthDialog/autoOAuthDialogController'; import { AutoOAuthDialogController } from 'sql/parts/accountManagement/autoOAuthDialog/autoOAuthDialogController';
import { AccountListStatusbarItem } from 'sql/parts/accountManagement/accountListStatusbar/accountListStatusbarItem'; import { AccountListStatusbarItem } from 'sql/parts/accountManagement/accountListStatusbar/accountListStatusbarItem';
import { AccountProviderAddedEventParams, UpdateAccountListEventParams } from 'sql/services/accountManagement/eventTypes'; import { AccountProviderAddedEventParams, UpdateAccountListEventParams } from 'sql/services/accountManagement/eventTypes';
import { IAccountManagementService } from 'sql/services/accountManagement/interfaces'; import { IAccountManagementService, AzureResource } from 'sql/services/accountManagement/interfaces';
import { IClipboardService } from 'vs/platform/clipboard/common/clipboardService'; import { IClipboardService } from 'vs/platform/clipboard/common/clipboardService';
export class AccountManagementService implements IAccountManagementService { export class AccountManagementService implements IAccountManagementService {
@@ -217,11 +217,12 @@ export class AccountManagementService implements IAccountManagementService {
/** /**
* Generates a security token by asking the account's provider * Generates a security token by asking the account's provider
* @param {Account} account Account to generate security token for * @param {Account} account Account to generate security token for
* @param {AzureResource} resource The resource to get the security token for
* @return {Thenable<{}>} Promise to return the security token * @return {Thenable<{}>} Promise to return the security token
*/ */
public getSecurityToken(account: sqlops.Account): Thenable<{}> { public getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}> {
return this.doWithProvider(account.key.providerId, provider => { return this.doWithProvider(account.key.providerId, provider => {
return provider.provider.getSecurityToken(account); return provider.provider.getSecurityToken(account, resource);
}); });
} }

View File

@@ -22,7 +22,7 @@ export interface IAccountManagementService {
addAccount(providerId: string): Thenable<void>; addAccount(providerId: string): Thenable<void>;
getAccountProviderMetadata(): Thenable<sqlops.AccountProviderMetadata[]>; getAccountProviderMetadata(): Thenable<sqlops.AccountProviderMetadata[]>;
getAccountsForProvider(providerId: string): Thenable<sqlops.Account[]>; getAccountsForProvider(providerId: string): Thenable<sqlops.Account[]>;
getSecurityToken(account: sqlops.Account): Thenable<{}>; getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}>;
removeAccount(accountKey: sqlops.AccountKey): Thenable<boolean>; removeAccount(accountKey: sqlops.AccountKey): Thenable<boolean>;
refreshAccount(account: sqlops.Account): Thenable<sqlops.Account>; refreshAccount(account: sqlops.Account): Thenable<sqlops.Account>;
@@ -44,6 +44,12 @@ export interface IAccountManagementService {
readonly updateAccountListEvent: Event<UpdateAccountListEventParams>; readonly updateAccountListEvent: Event<UpdateAccountListEventParams>;
} }
// Enum matching the AzureResource enum from sqlops.d.ts
export enum AzureResource {
ResourceManagement = 0,
Sql = 1
}
export interface IAccountStore { export interface IAccountStore {
/** /**
* Adds the provided account if the account doesn't exist. Updates the account if it already exists * Adds the provided account if the account doesn't exist. Updates the account if it already exists

10
src/sql/sqlops.d.ts vendored
View File

@@ -1915,7 +1915,7 @@ declare module 'sqlops' {
* @param {Account} account Account to generate security token for * @param {Account} account Account to generate security token for
* @return {Thenable<{}>} Promise to return the security token * @return {Thenable<{}>} Promise to return the security token
*/ */
export function getSecurityToken(account: Account): Thenable<{}>; export function getSecurityToken(account: Account, resource: AzureResource): Thenable<{}>;
/** /**
* An [event](#Event) which fires when the accounts have changed. * An [event](#Event) which fires when the accounts have changed.
@@ -1988,6 +1988,11 @@ declare module 'sqlops' {
isStale: boolean; isStale: boolean;
} }
export enum AzureResource {
ResourceManagement = 0,
Sql = 1
}
export interface DidChangeAccountsParams { export interface DidChangeAccountsParams {
// Updated accounts // Updated accounts
accounts: Account[]; accounts: Account[];
@@ -2045,9 +2050,10 @@ declare module 'sqlops' {
/** /**
* Generates a security token for the provided account * Generates a security token for the provided account
* @param {Account} account The account to generate a security token for * @param {Account} account The account to generate a security token for
* @param {AzureResource} resource The resource to get the token for
* @return {Thenable<{}>} Promise to return a security token object * @return {Thenable<{}>} Promise to return a security token object
*/ */
getSecurityToken(account: Account): Thenable<{}>; getSecurityToken(account: Account, resource: AzureResource): Thenable<{}>;
/** /**
* Prompts the user to enter account information. * Prompts the user to enter account information.

View File

@@ -313,6 +313,11 @@ export class TreeComponentItem extends TreeItem {
checked?: boolean; checked?: boolean;
} }
export enum AzureResource {
ResourceManagement = 0,
Sql = 1
}
export class SqlThemeIcon { export class SqlThemeIcon {
static readonly Folder = new SqlThemeIcon('Folder'); static readonly Folder = new SqlThemeIcon('Folder');

View File

@@ -89,12 +89,12 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape {
return Promise.all(promises).then(() => resultAccounts); return Promise.all(promises).then(() => resultAccounts);
} }
public $getSecurityToken(account: sqlops.Account): Thenable<{}> { public $getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}> {
return this.$getAllAccounts().then(() => { return this.$getAllAccounts().then(() => {
for (const handle in this._accounts) { for (const handle in this._accounts) {
const providerHandle = parseInt(handle); const providerHandle = parseInt(handle);
if (this._accounts[handle].findIndex((acct) => acct.key.accountId === account.key.accountId) !== -1) { if (this._accounts[handle].findIndex((acct) => acct.key.accountId === account.key.accountId) !== -1) {
return this._withProvider(providerHandle, (provider: sqlops.AccountProvider) => provider.getSecurityToken(account)); return this._withProvider(providerHandle, (provider: sqlops.AccountProvider) => provider.getSecurityToken(account, resource));
} }
} }

View File

@@ -76,8 +76,8 @@ export class MainThreadAccountManagement implements MainThreadAccountManagementS
clear(accountKey: sqlops.AccountKey): Thenable<void> { clear(accountKey: sqlops.AccountKey): Thenable<void> {
return self._proxy.$clear(handle, accountKey); return self._proxy.$clear(handle, accountKey);
}, },
getSecurityToken(account: sqlops.Account): Thenable<{}> { getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}> {
return self._proxy.$getSecurityToken(account); return self._proxy.$getSecurityToken(account, resource);
}, },
initialize(restoredAccounts: sqlops.Account[]): Thenable<sqlops.Account[]> { initialize(restoredAccounts: sqlops.Account[]): Thenable<sqlops.Account[]> {
return self._proxy.$initialize(handle, restoredAccounts); return self._proxy.$initialize(handle, restoredAccounts);

View File

@@ -97,8 +97,8 @@ export function createApiFactory(
getAllAccounts(): Thenable<sqlops.Account[]> { getAllAccounts(): Thenable<sqlops.Account[]> {
return extHostAccountManagement.$getAllAccounts(); return extHostAccountManagement.$getAllAccounts();
}, },
getSecurityToken(account: sqlops.Account): Thenable<{}> { getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}> {
return extHostAccountManagement.$getSecurityToken(account); return extHostAccountManagement.$getSecurityToken(account, resource);
}, },
onDidChangeAccounts(listener: (e: sqlops.DidChangeAccountsParams) => void, thisArgs?: any, disposables?: extHostTypes.Disposable[]) { onDidChangeAccounts(listener: (e: sqlops.DidChangeAccountsParams) => void, thisArgs?: any, disposables?: extHostTypes.Disposable[]) {
return extHostAccountManagement.onDidChangeAccounts(listener, thisArgs, disposables); return extHostAccountManagement.onDidChangeAccounts(listener, thisArgs, disposables);
@@ -452,7 +452,8 @@ export function createApiFactory(
Orientation: sqlExtHostTypes.Orientation, Orientation: sqlExtHostTypes.Orientation,
SqlThemeIcon: sqlExtHostTypes.SqlThemeIcon, SqlThemeIcon: sqlExtHostTypes.SqlThemeIcon,
TreeComponentItem: sqlExtHostTypes.TreeComponentItem, TreeComponentItem: sqlExtHostTypes.TreeComponentItem,
nb: nb nb: nb,
AzureResource: sqlExtHostTypes.AzureResource
}; };
} }
}; };

View File

@@ -27,7 +27,7 @@ import {
export abstract class ExtHostAccountManagementShape { export abstract class ExtHostAccountManagementShape {
$autoOAuthCancelled(handle: number): Thenable<void> { throw ni(); } $autoOAuthCancelled(handle: number): Thenable<void> { throw ni(); }
$clear(handle: number, accountKey: sqlops.AccountKey): Thenable<void> { throw ni(); } $clear(handle: number, accountKey: sqlops.AccountKey): Thenable<void> { throw ni(); }
$getSecurityToken(account: sqlops.Account): Thenable<{}> { throw ni(); } $getSecurityToken(account: sqlops.Account, resource?: sqlops.AzureResource): Thenable<{}> { throw ni(); }
$initialize(handle: number, restoredAccounts: sqlops.Account[]): Thenable<sqlops.Account[]> { throw ni(); } $initialize(handle: number, restoredAccounts: sqlops.Account[]): Thenable<sqlops.Account[]> { throw ni(); }
$prompt(handle: number): Thenable<sqlops.Account> { throw ni(); } $prompt(handle: number): Thenable<sqlops.Account> { throw ni(); }
$refresh(handle: number, account: sqlops.Account): Thenable<sqlops.Account> { throw ni(); } $refresh(handle: number, account: sqlops.Account): Thenable<sqlops.Account> { throw ni(); }

View File

@@ -136,7 +136,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and call create firewall rule in resource provider // Then: it should get security token from account management service and call create firewall rule in resource provider
deferredPromise.promise.then(() => { deferredPromise.promise.then(() => {
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny()), TypeMoq.Times.once()); mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockFirewallRuleDialog.verify(x => x.close(), TypeMoq.Times.once()); mockFirewallRuleDialog.verify(x => x.close(), TypeMoq.Times.once());
mockFirewallRuleDialog.verify(x => x.onServiceComplete(), TypeMoq.Times.once()); mockFirewallRuleDialog.verify(x => x.onServiceComplete(), TypeMoq.Times.once());
@@ -165,7 +165,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and an error dialog should have been opened // Then: it should get security token from account management service and an error dialog should have been opened
deferredPromise.promise.then(() => { deferredPromise.promise.then(() => {
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny()), TypeMoq.Times.once()); mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.never()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.never());
done(); done();
@@ -193,7 +193,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and an error dialog should have been opened // Then: it should get security token from account management service and an error dialog should have been opened
deferredPromise.promise.then(() => { deferredPromise.promise.then(() => {
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny()), TypeMoq.Times.once()); mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
done(); done();
@@ -221,7 +221,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and an error dialog should have been opened // Then: it should get security token from account management service and an error dialog should have been opened
deferredPromise.promise.then(() => { deferredPromise.promise.then(() => {
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny()), TypeMoq.Times.once()); mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
done(); done();
@@ -232,7 +232,7 @@ suite('Firewall rule dialog controller tests', () => {
function getMockAccountManagementService(resolveSecurityToken: boolean): TypeMoq.Mock<AccountManagementTestService> { function getMockAccountManagementService(resolveSecurityToken: boolean): TypeMoq.Mock<AccountManagementTestService> {
let accountManagementTestService = new AccountManagementTestService(); let accountManagementTestService = new AccountManagementTestService();
let mockAccountManagementService = TypeMoq.Mock.ofInstance(accountManagementTestService); let mockAccountManagementService = TypeMoq.Mock.ofInstance(accountManagementTestService);
mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny())) mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()))
.returns(() => resolveSecurityToken ? Promise.resolve({}) : Promise.reject(null).then()); .returns(() => resolveSecurityToken ? Promise.resolve({}) : Promise.reject(null).then());
return mockAccountManagementService; return mockAccountManagementService;
} }

View File

@@ -34,6 +34,7 @@ import * as assert from 'assert';
import * as TypeMoq from 'typemoq'; import * as TypeMoq from 'typemoq';
import { IConnectionProfileGroup, ConnectionProfileGroup } from 'sql/parts/connection/common/connectionProfileGroup'; import { IConnectionProfileGroup, ConnectionProfileGroup } from 'sql/parts/connection/common/connectionProfileGroup';
import { ConnectionProfile } from 'sql/parts/connection/common/connectionProfile'; import { ConnectionProfile } from 'sql/parts/connection/common/connectionProfile';
import { AccountManagementTestService } from 'sqltest/stubs/accountManagementStubs';
suite('SQL ConnectionManagementService tests', () => { suite('SQL ConnectionManagementService tests', () => {
@@ -46,6 +47,7 @@ suite('SQL ConnectionManagementService tests', () => {
let mssqlConnectionProvider: TypeMoq.Mock<ConnectionProviderStub>; let mssqlConnectionProvider: TypeMoq.Mock<ConnectionProviderStub>;
let workspaceConfigurationServiceMock: TypeMoq.Mock<WorkspaceConfigurationTestService>; let workspaceConfigurationServiceMock: TypeMoq.Mock<WorkspaceConfigurationTestService>;
let resourceProviderStubMock: TypeMoq.Mock<ResourceProviderStub>; let resourceProviderStubMock: TypeMoq.Mock<ResourceProviderStub>;
let accountManagementService: TypeMoq.Mock<AccountManagementTestService>;
let none: void; let none: void;
@@ -88,6 +90,7 @@ suite('SQL ConnectionManagementService tests', () => {
mssqlConnectionProvider = TypeMoq.Mock.ofType(ConnectionProviderStub); mssqlConnectionProvider = TypeMoq.Mock.ofType(ConnectionProviderStub);
let resourceProviderStub = new ResourceProviderStub(); let resourceProviderStub = new ResourceProviderStub();
resourceProviderStubMock = TypeMoq.Mock.ofInstance(resourceProviderStub); resourceProviderStubMock = TypeMoq.Mock.ofInstance(resourceProviderStub);
accountManagementService = TypeMoq.Mock.ofType(AccountManagementTestService);
let root = new ConnectionProfileGroup(ConnectionProfileGroup.RootGroupName, undefined, ConnectionProfileGroup.RootGroupName, undefined, undefined); let root = new ConnectionProfileGroup(ConnectionProfileGroup.RootGroupName, undefined, ConnectionProfileGroup.RootGroupName, undefined, undefined);
root.connections = [ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile)]; root.connections = [ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile)];
@@ -162,7 +165,8 @@ suite('SQL ConnectionManagementService tests', () => {
undefined, undefined,
resourceProviderStubMock.object, resourceProviderStubMock.object,
undefined, undefined,
undefined undefined,
accountManagementService.object
); );
return connectionManagementService; return connectionManagementService;
} }
@@ -837,4 +841,44 @@ suite('SQL ConnectionManagementService tests', () => {
// Then undefined is returned // Then undefined is returned
assert.equal(foundUri, undefined); assert.equal(foundUri, undefined);
}); });
test('addSavedPassword fills in Azure access tokens for Azure accounts', async () => {
// Set up a connection profile that uses Azure
let azureConnectionProfile = ConnectionProfile.fromIConnectionProfile(capabilitiesService, connectionProfile);
azureConnectionProfile.authenticationType = 'AzureMFA';
let username = 'testuser@microsoft.com';
azureConnectionProfile.userName = username;
let servername = 'test-database.database.windows.net';
azureConnectionProfile.serverName = servername;
// Set up the account management service to return a token for the given user
accountManagementService.setup(x => x.getAccountsForProvider(TypeMoq.It.isAny())).returns(providerId => Promise.resolve<sqlops.Account[]>([
{
key: {
accountId: username,
providerId: providerId
},
displayInfo: undefined,
isStale: false,
properties: undefined
}
]));
let testToken = 'testToken';
accountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
azurePublicCloud: {
token: testToken
}
}));
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({
profile: profile,
savedCred: false
}));
// If I call addSavedPassword
let profileWithCredentials = await connectionManagementService.addSavedPassword(azureConnectionProfile);
// Then the returned profile has the account token set
assert.equal(profileWithCredentials.userName, username);
assert.equal(profileWithCredentials.options['azureAccountToken'], testToken);
});
}); });

View File

@@ -50,7 +50,7 @@ export class AccountManagementTestService implements IAccountManagementService {
return undefined; return undefined;
} }
getSecurityToken(account: sqlops.Account): Thenable<{}> { getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}> {
return undefined; return undefined;
} }
@@ -88,7 +88,7 @@ export class AccountProviderStub implements sqlops.AccountProvider {
return Promise.resolve(); return Promise.resolve();
} }
getSecurityToken(account: sqlops.Account): Thenable<{}> { getSecurityToken(account: sqlops.Account, resource: sqlops.AzureResource): Thenable<{}> {
return Promise.resolve({}); return Promise.resolve({});
} }

View File

@@ -15,7 +15,7 @@ import { TestInstantiationService } from 'vs/platform/instantiation/test/common/
import { IRPCProtocol } from 'vs/workbench/services/extensions/node/proxyIdentifier'; import { IRPCProtocol } from 'vs/workbench/services/extensions/node/proxyIdentifier';
import { SqlMainContext } from 'sql/workbench/api/node/sqlExtHost.protocol'; import { SqlMainContext } from 'sql/workbench/api/node/sqlExtHost.protocol';
import { MainThreadAccountManagement } from 'sql/workbench/api/node/mainThreadAccountManagement'; import { MainThreadAccountManagement } from 'sql/workbench/api/node/mainThreadAccountManagement';
import { IAccountManagementService } from 'sql/services/accountManagement/interfaces'; import { IAccountManagementService, AzureResource } from 'sql/services/accountManagement/interfaces';
import { createDecorator } from 'vs/platform/instantiation/common/instantiation'; import { createDecorator } from 'vs/platform/instantiation/common/instantiation';
const IRPCProtocol = createDecorator<IRPCProtocol>('rpcProtocol'); const IRPCProtocol = createDecorator<IRPCProtocol>('rpcProtocol');
@@ -366,7 +366,7 @@ suite('ExtHostAccountManagement', () => {
extHost.$getAllAccounts() extHost.$getAllAccounts()
.then((accounts) => { .then((accounts) => {
// If: I get security token it will not throw // If: I get security token it will not throw
return extHost.$getSecurityToken(mockAccount1); return extHost.$getSecurityToken(mockAccount1, AzureResource.ResourceManagement);
} }
).then(() => done(), (err) => done(new Error(err))); ).then(() => done(), (err) => done(new Error(err)));
}); });
@@ -417,7 +417,7 @@ suite('ExtHostAccountManagement', () => {
extHost.$getAllAccounts() extHost.$getAllAccounts()
.then(accounts => { .then(accounts => {
return extHost.$getSecurityToken(mockAccount2); return extHost.$getSecurityToken(mockAccount2, AzureResource.ResourceManagement);
}) })
.then((noError) => { .then((noError) => {
done(new Error('Expected getSecurityToken to throw')); done(new Error('Expected getSecurityToken to throw'));
@@ -447,7 +447,7 @@ function getMockAccountManagementService(accounts: sqlops.Account[]): TypeMoq.Mo
mockAccountManagementService.setup(x => x.getAccountsForProvider(TypeMoq.It.isAny())) mockAccountManagementService.setup(x => x.getAccountsForProvider(TypeMoq.It.isAny()))
.returns(() => Promise.resolve(accounts)); .returns(() => Promise.resolve(accounts));
mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isValue(accounts[0]))) mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isValue(accounts[0]), TypeMoq.It.isAny()))
.returns(() => Promise.resolve({})); .returns(() => Promise.resolve({}));
mockAccountManagementService.setup(x => x.updateAccountListEvent) mockAccountManagementService.setup(x => x.updateAccountListEvent)
.returns(() => () => { return undefined; } ); .returns(() => () => { return undefined; } );

View File

@@ -557,6 +557,8 @@ export class Workbench extends Disposable implements IPartService {
serviceCollection.set(IServerGroupController, this.instantiationService.createInstance(ServerGroupController)); serviceCollection.set(IServerGroupController, this.instantiationService.createInstance(ServerGroupController));
serviceCollection.set(ICredentialsService, this.instantiationService.createInstance(CredentialsService)); serviceCollection.set(ICredentialsService, this.instantiationService.createInstance(CredentialsService));
serviceCollection.set(IResourceProviderService, this.instantiationService.createInstance(ResourceProviderService)); serviceCollection.set(IResourceProviderService, this.instantiationService.createInstance(ResourceProviderService));
let accountManagementService = this.instantiationService.createInstance(AccountManagementService, undefined);
serviceCollection.set(IAccountManagementService, accountManagementService);
let connectionManagementService = this.instantiationService.createInstance(ConnectionManagementService, undefined, undefined); let connectionManagementService = this.instantiationService.createInstance(ConnectionManagementService, undefined, undefined);
serviceCollection.set(IConnectionManagementService, connectionManagementService); serviceCollection.set(IConnectionManagementService, connectionManagementService);
serviceCollection.set(ISerializationService, this.instantiationService.createInstance(SerializationService)); serviceCollection.set(ISerializationService, this.instantiationService.createInstance(SerializationService));
@@ -577,8 +579,6 @@ export class Workbench extends Disposable implements IPartService {
serviceCollection.set(IFileBrowserService, this.instantiationService.createInstance(FileBrowserService)); serviceCollection.set(IFileBrowserService, this.instantiationService.createInstance(FileBrowserService));
serviceCollection.set(IFileBrowserDialogController, this.instantiationService.createInstance(FileBrowserDialogController)); serviceCollection.set(IFileBrowserDialogController, this.instantiationService.createInstance(FileBrowserDialogController));
serviceCollection.set(IInsightsDialogService, this.instantiationService.createInstance(InsightsDialogService)); serviceCollection.set(IInsightsDialogService, this.instantiationService.createInstance(InsightsDialogService));
let accountManagementService = this.instantiationService.createInstance(AccountManagementService, undefined);
serviceCollection.set(IAccountManagementService, accountManagementService);
let notebookService = this.instantiationService.createInstance(NotebookService); let notebookService = this.instantiationService.createInstance(NotebookService);
serviceCollection.set(INotebookService, notebookService); serviceCollection.set(INotebookService, notebookService);
serviceCollection.set(IAccountPickerService, this.instantiationService.createInstance(AccountPickerService)); serviceCollection.set(IAccountPickerService, this.instantiationService.createInstance(AccountPickerService));