Add MSAL Authentication Library support (#21024)

This commit is contained in:
Christopher Suh
2022-11-23 17:06:44 -05:00
committed by GitHub
parent fba47815e2
commit 86c3f315f2
32 changed files with 1502 additions and 320 deletions

View File

@@ -50,6 +50,7 @@ import { Tenant, TenantListDelegate, TenantListRenderer } from 'sql/workbench/se
import { IAccountManagementService } from 'sql/platform/accounts/common/interfaces';
export const VIEWLET_ID = 'workbench.view.accountpanel';
export type AuthLibrary = 'ADAL' | 'MSAL';
export class AccountPaneContainer extends ViewPaneContainer {
@@ -376,9 +377,14 @@ export class AccountDialog extends Modal {
this._splitView!.layout(DOM.getContentHeight(this._container!));
// Set the initial items of the list
providerView.updateAccounts(newProvider.initialAccounts);
const authLibrary: AuthLibrary = this._configurationService.getValue('azure.authenticationLibrary');
let updatedAccounts: azdata.Account[];
if (authLibrary) {
updatedAccounts = filterAccounts(newProvider.initialAccounts, authLibrary);
}
providerView.updateAccounts(updatedAccounts);
if (newProvider.initialAccounts.length > 0 && this._splitViewContainer!.hidden) {
if (updatedAccounts.length > 0 && this._splitViewContainer!.hidden) {
this.showSplitView();
}
@@ -413,7 +419,12 @@ export class AccountDialog extends Modal {
if (!providerMapping || !providerMapping.view) {
return;
}
providerMapping.view.updateAccounts(args.accountList);
const authLibrary: AuthLibrary = this._configurationService.getValue('azure.authenticationLibrary');
let updatedAccounts: azdata.Account[];
if (authLibrary) {
updatedAccounts = filterAccounts(args.accountList, authLibrary);
}
providerMapping.view.updateAccounts(updatedAccounts);
if (args.accountList.length > 0 && this._splitViewContainer!.hidden) {
this.showSplitView();
@@ -480,3 +491,27 @@ export class AccountDialog extends Modal {
v.addAccountAction.run();
}
}
// Filter accounts based on currently selected Auth Library:
// if the account key is present, filter based on current auth library
// if there is no account key (pre-MSAL account), then it is an ADAL account and
// should be displayed as long as ADAL is the currently selected auth library
export function filterAccounts(accounts: azdata.Account[], authLibrary: AuthLibrary): azdata.Account[] {
let filteredAccounts = accounts.filter(account => {
if (account.key.authLibrary) {
if (account.key.authLibrary === authLibrary) {
return true;
} else {
return false;
}
} else {
if (authLibrary === 'ADAL') {
return true;
} else {
return false;
}
}
});
return filteredAccounts;
}

View File

@@ -24,6 +24,9 @@ import { values } from 'vs/base/common/collections';
import { ILogService } from 'vs/platform/log/common/log';
import { INotificationService, Severity, INotification } from 'vs/platform/notification/common/notification';
import { Action } from 'vs/base/common/actions';
import { DisposableStore } from 'vs/base/common/lifecycle';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
import { AuthLibrary, filterAccounts } from 'sql/workbench/services/accountManagement/browser/accountDialog';
export class AccountManagementService implements IAccountManagementService {
// CONSTANTS ///////////////////////////////////////////////////////////
@@ -36,6 +39,8 @@ export class AccountManagementService implements IAccountManagementService {
private _accountDialogController?: AccountDialogController;
private _autoOAuthDialogController?: AutoOAuthDialogController;
private _mementoContext?: Memento;
protected readonly disposables = new DisposableStore();
private readonly configurationService: IConfigurationService;
// EVENT EMITTERS //////////////////////////////////////////////////////
private _addAccountProviderEmitter: Emitter<AccountProviderAddedEventParams>;
@@ -54,7 +59,8 @@ export class AccountManagementService implements IAccountManagementService {
@IClipboardService private _clipboardService: IClipboardService,
@IOpenerService private _openerService: IOpenerService,
@ILogService private readonly _logService: ILogService,
@INotificationService private readonly _notificationService: INotificationService
@INotificationService private readonly _notificationService: INotificationService,
@IConfigurationService configurationService: IConfigurationService
) {
this._mementoContext = new Memento(AccountManagementService.ACCOUNT_MEMENTO, this._storageService);
const mementoObj = this._mementoContext.getMemento(StorageScope.GLOBAL, StorageTarget.MACHINE);
@@ -64,8 +70,10 @@ export class AccountManagementService implements IAccountManagementService {
this._addAccountProviderEmitter = new Emitter<AccountProviderAddedEventParams>();
this._removeAccountProviderEmitter = new Emitter<azdata.AccountProviderMetadata>();
this._updateAccountListEmitter = new Emitter<UpdateAccountListEventParams>();
this.configurationService = configurationService;
_storageService.onWillSaveState(() => this.shutdown());
this.registerListeners();
}
private get autoOAuthDialogController(): AutoOAuthDialogController {
@@ -136,6 +144,10 @@ export class AccountManagementService implements IAccountManagementService {
}
let result = await this._accountStore.addOrUpdate(account);
if (!result) {
this._logService.error('adding account failed');
throw Error('Adding account failed, check Azure Accounts log for more info.')
}
if (result.accountAdded) {
// Add the account to the list
provider.accounts.push(result.changedAccount);
@@ -458,10 +470,15 @@ export class AccountManagementService implements IAccountManagementService {
});
}
const authLibrary: AuthLibrary = this.configurationService.getValue('azure.authenticationLibrary');
let updatedAccounts: azdata.Account[]
if (authLibrary) {
updatedAccounts = filterAccounts(provider.accounts, authLibrary);
}
// Step 2) Fire the event
let eventArg: UpdateAccountListEventParams = {
providerId: provider.metadata.id,
accountList: provider.accounts
accountList: updatedAccounts ?? provider.accounts
};
this._updateAccountListEmitter.fire(eventArg);
}
@@ -475,6 +492,39 @@ export class AccountManagementService implements IAccountManagementService {
provider.accounts.splice(indexToRemove, 1, modifiedAccount);
}
}
private registerListeners(): void {
this.disposables.add(this.configurationService.onDidChangeConfiguration(async e => {
if (e.affectsConfiguration('azure.authenticationLibrary')) {
const authLibrary: AuthLibrary = this.configurationService.getValue('azure.authenticationLibrary');
if (authLibrary) {
let accounts = await this._accountStore.getAllAccounts();
if (accounts) {
let updatedAccounts = filterAccounts(accounts, authLibrary);
let eventArg: UpdateAccountListEventParams;
if (updatedAccounts.length > 0) {
updatedAccounts.forEach(account => {
if (account.key.authLibrary === 'MSAL') {
account.isStale = false;
}
});
eventArg = {
providerId: updatedAccounts[0].key.providerId,
accountList: updatedAccounts
};
} else { // default to public cloud if no accounts
eventArg = {
providerId: 'azure_publicCloud',
accountList: updatedAccounts
};
}
this._updateAccountListEmitter.fire(eventArg);
}
}
}
}));
}
}
/**

View File

@@ -18,6 +18,7 @@ import { EventVerifierSingle } from 'sql/base/test/common/event';
import { TestNotificationService } from 'vs/platform/notification/test/common/testNotificationService';
import { AccountDialog } from 'sql/workbench/services/accountManagement/browser/accountDialog';
import { Emitter } from 'vs/base/common/event';
import { TestConfigurationService } from 'sql/platform/connection/test/common/testConfigurationService';
// SUITE CONSTANTS /////////////////////////////////////////////////////////
const hasAccountProvider: azdata.AccountProviderMetadata = {
@@ -530,9 +531,10 @@ function getTestState(): AccountManagementState {
.returns(() => mockAccountStore.object);
const testNotificationService = new TestNotificationService();
const testConfigurationService = new TestConfigurationService();
// Create the account management service
let ams = new AccountManagementService(mockInstantiationService.object, new TestStorageService(), undefined!, undefined!, undefined!, testNotificationService);
let ams = new AccountManagementService(mockInstantiationService.object, new TestStorageService(), undefined!, undefined!, undefined!, testNotificationService, testConfigurationService);
// Wire up event handlers
let evUpdate = new EventVerifierSingle<UpdateAccountListEventParams>();

View File

@@ -26,6 +26,7 @@ import { ILayoutService } from 'vs/platform/layout/browser/layoutService';
import { ConnectionWidget } from 'sql/workbench/services/connection/browser/connectionWidget';
import { ILogService } from 'vs/platform/log/common/log';
import { IErrorMessageService } from 'sql/platform/errorMessage/common/errorMessageService';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
/**
* Connection Widget clas for CMS Connections
@@ -47,8 +48,9 @@ export class CmsConnectionWidget extends ConnectionWidget {
@IAccountManagementService _accountManagementService: IAccountManagementService,
@ILogService _logService: ILogService,
@IErrorMessageService _errorMessageService: IErrorMessageService,
@IConfigurationService configurationService: IConfigurationService
) {
super(options, callbacks, providerName, _themeService, _contextViewService, _connectionManagementService, _accountManagementService, _logService, _errorMessageService);
super(options, callbacks, providerName, _themeService, _contextViewService, _connectionManagementService, _accountManagementService, _logService, _errorMessageService, configurationService);
let authTypeOption = this._optionsMaps[ConnectionOptionSpecialType.authType];
if (authTypeOption) {
let authTypeDefault = this.getAuthTypeDefault(authTypeOption, OS);

View File

@@ -27,7 +27,6 @@ import { AzureResource, ConnectionOptionSpecialType } from 'sql/workbench/api/co
import { IAccountManagementService } from 'sql/platform/accounts/common/interfaces';
import * as azdata from 'azdata';
import * as nls from 'vs/nls';
import * as errors from 'vs/base/common/errors';
import { Disposable } from 'vs/base/common/lifecycle';

View File

@@ -36,6 +36,8 @@ import Severity from 'vs/base/common/severity';
import { ConnectionStringOptions } from 'sql/platform/capabilities/common/capabilitiesService';
import { isFalsyOrWhitespace } from 'vs/base/common/strings';
import { AuthenticationType } from 'sql/platform/connection/common/constants';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
import { AuthLibrary, filterAccounts } from 'sql/workbench/services/accountManagement/browser/accountDialog';
const ConnectionStringText = localize('connectionWidget.connectionString', "Connection string");
@@ -107,6 +109,7 @@ export class ConnectionWidget extends lifecycle.Disposable {
color: undefined,
description: undefined,
};
private readonly configurationService: IConfigurationService;
constructor(options: azdata.ConnectionOption[],
callbacks: IConnectionComponentCallbacks,
providerName: string,
@@ -115,7 +118,8 @@ export class ConnectionWidget extends lifecycle.Disposable {
@IConnectionManagementService private _connectionManagementService: IConnectionManagementService,
@IAccountManagementService private _accountManagementService: IAccountManagementService,
@ILogService protected _logService: ILogService,
@IErrorMessageService private _errorMessageService: IErrorMessageService
@IErrorMessageService private _errorMessageService: IErrorMessageService,
@IConfigurationService configurationService: IConfigurationService
) {
super();
this._callbacks = callbacks;
@@ -135,6 +139,7 @@ export class ConnectionWidget extends lifecycle.Disposable {
}
this._providerName = providerName;
this._connectionStringOptions = this._connectionManagementService.getProviderProperties(this._providerName).connectionStringOptions;
this.configurationService = configurationService;
}
protected getAuthTypeDefault(option: azdata.ConnectionOption, os: OperatingSystem): string {
@@ -591,7 +596,12 @@ export class ConnectionWidget extends lifecycle.Disposable {
private async fillInAzureAccountOptions(): Promise<void> {
let oldSelection = this._azureAccountDropdown.value;
const accounts = await this._accountManagementService.getAccounts();
this._azureAccountList = accounts.filter(a => a.key.providerId.startsWith('azure'));
const updatedAccounts = accounts.filter(a => a.key.providerId.startsWith('azure'));
const authLibrary: AuthLibrary = this.configurationService.getValue('azure.authenticationLibrary');
if (authLibrary) {
this._azureAccountList = filterAccounts(updatedAccounts, authLibrary);
}
let accountDropdownOptions: SelectOptionItemSQL[] = this._azureAccountList.map(account => {
return {
text: account.displayInfo.displayName,