Fix race condition in account management (#14218)

This commit is contained in:
Charles Gagnon
2021-02-09 22:40:46 -08:00
committed by GitHub
parent a226c75d8c
commit d78d89e326

View File

@@ -15,11 +15,12 @@ import { IMainContext } from 'vs/workbench/api/common/extHost.protocol';
import { Event, Emitter } from 'vs/base/common/event'; import { Event, Emitter } from 'vs/base/common/event';
import { values } from 'vs/base/common/collections'; import { values } from 'vs/base/common/collections';
type ProviderAndAccount = { provider: azdata.AccountProvider, account: azdata.Account };
export class ExtHostAccountManagement extends ExtHostAccountManagementShape { export class ExtHostAccountManagement extends ExtHostAccountManagementShape {
private _handlePool: number = 0; private _handlePool: number = 0;
private _proxy: MainThreadAccountManagementShape; private _proxy: MainThreadAccountManagementShape;
private _providers: { [handle: number]: AccountProviderWithMetadata } = {}; private _providers: { [handle: number]: AccountProviderWithMetadata } = {};
private _accounts: { [handle: number]: azdata.Account[] } = {};
private readonly _onDidChangeAccounts = new Emitter<azdata.DidChangeAccountsParams>(); private readonly _onDidChangeAccounts = new Emitter<azdata.DidChangeAccountsParams>();
constructor(mainContext: IMainContext) { constructor(mainContext: IMainContext) {
@@ -62,14 +63,19 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape {
this._proxy.$accountUpdated(updatedAccount); this._proxy.$accountUpdated(updatedAccount);
} }
public $getAllAccounts(): Thenable<azdata.Account[]> { public $getAllAccounts(): Thenable<azdata.Account[]> {
return this.getAllProvidersAndAccounts().then(providersAndAccounts => {
return providersAndAccounts.map(providerAndAccount => providerAndAccount.account);
});
}
private async getAllProvidersAndAccounts(): Promise<ProviderAndAccount[]> {
if (Object.keys(this._providers).length === 0) { if (Object.keys(this._providers).length === 0) {
throw new Error('No account providers registered.'); throw new Error('No account providers registered.');
} }
this._accounts = {}; const resultProviderAndAccounts: ProviderAndAccount[] = [];
const resultAccounts: azdata.Account[] = [];
const promises: Thenable<void>[] = []; const promises: Thenable<void>[] = [];
@@ -79,37 +85,31 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape {
const provider = this._providers[providerHandle]; const provider = this._providers[providerHandle];
promises.push(this._proxy.$getAccountsForProvider(provider.metadata.id).then( promises.push(this._proxy.$getAccountsForProvider(provider.metadata.id).then(
(accounts) => { (accounts) => {
this._accounts[providerHandle] = accounts; resultProviderAndAccounts.push(...accounts.map(account => { return { provider: provider.provider, account }; }));
resultAccounts.push(...accounts);
} }
)); ));
} }
return Promise.all(promises).then(() => resultAccounts); await Promise.all(promises);
return resultProviderAndAccounts;
} }
public $getSecurityToken(account: azdata.Account, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{}> { public $getSecurityToken(account: azdata.Account, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{}> {
return this.$getAllAccounts().then(() => { return this.getAllProvidersAndAccounts().then(providerAndAccounts => {
for (const handle in this._accounts) { const providerAndAccount = providerAndAccounts.find(providerAndAccount => providerAndAccount.account.key.accountId === account.key.accountId);
const providerHandle = parseInt(handle); if (providerAndAccount) {
if (this._accounts[handle].findIndex((acct) => acct.key.accountId === account.key.accountId) !== -1) { return providerAndAccount.provider.getSecurityToken(account, resource);
return this._withProvider(providerHandle, (provider: azdata.AccountProvider) => provider.getSecurityToken(account, resource));
} }
}
throw new Error(`Account ${account.key.accountId} not found.`); throw new Error(`Account ${account.key.accountId} not found.`);
}); });
} }
public $getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{ token: string }> { public $getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{ token: string }> {
return this.$getAllAccounts().then(() => { return this.getAllProvidersAndAccounts().then(providerAndAccounts => {
for (const handle in this._accounts) { const providerAndAccount = providerAndAccounts.find(providerAndAccount => providerAndAccount.account.key.accountId === account.key.accountId);
const providerHandle = parseInt(handle); if (providerAndAccount) {
if (this._accounts[handle].findIndex((acct) => acct.key.accountId === account.key.accountId) !== -1) { return providerAndAccount.provider.getAccountSecurityToken(account, tenant, resource);
return this._withProvider(providerHandle, (provider: azdata.AccountProvider) => provider.getAccountSecurityToken(account, tenant, resource));
} }
}
throw new Error(`Account ${account.key.accountId} not found.`); throw new Error(`Account ${account.key.accountId} not found.`);
}); });
} }