Large cleanup of AzureCore - Introduction of getAccountSecurityToken and deprecation of getSecurityToken (#11446)

* do a large cleanup of azurecore

* Fix tests

* Rework Device Code

* Fix tests

* Fix AE scenario

* Fix firewall rule - clenaup logging

* Shorthand syntax

* Fix firewall tests

* Start on tests for azureAuth

* Add more tests

* Address comments

* Add a few more important tests

* Don't throw error on old code

* Fill in todo
This commit is contained in:
Amir Omidi
2020-07-22 15:03:42 -07:00
committed by GitHub
parent a61b85c9ff
commit 587abd43c2
40 changed files with 1045 additions and 895 deletions

View File

@@ -3,104 +3,39 @@
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as nls from 'vscode-nls';
import axios, { AxiosResponse, AxiosRequestConfig } from 'axios';
import * as qs from 'qs';
import * as url from 'url';
import {
AzureAccountProviderMetadata,
Tenant,
AzureAccount,
Resource,
AzureAccountProviderMetadata,
AzureAuthType,
Subscription,
Deferred
Deferred,
Resource,
Tenant
} from '../interfaces';
import * as url from 'url';
import { SimpleTokenCache } from '../simpleTokenCache';
import { MemoryDatabase } from '../utils/memoryDatabase';
import axios, { AxiosRequestConfig, AxiosResponse } from 'axios';
import { Logger } from '../../utils/Logger';
import * as qs from 'qs';
import { AzureAuthError } from './azureAuthError';
const localize = nls.loadMessageBundle();
export interface AccountKey {
/**
* Account Key - uniquely identifies an account
*/
key: string
}
export interface AccessToken extends AccountKey {
/**
* Access Token
*/
token: string;
}
export interface RefreshToken extends AccountKey {
/**
* Refresh Token
*/
token: string;
/**
* Account Key
*/
key: string
}
export interface TokenResponse {
[tenantId: string]: Token
}
export interface Token extends AccountKey {
/**
* Access token
*/
token: string;
/**
* TokenType
*/
tokenType: string;
}
export interface TokenClaims { // https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens
aud: string;
iss: string;
iat: number;
idp: string,
nbf: number;
exp: number;
c_hash: string;
at_hash: string;
aio: string;
preferred_username: string;
email: string;
name: string;
nonce: string;
oid: string;
roles: string[];
rh: string;
sub: string;
tid: string;
unique_name: string;
uti: string;
ver: string;
}
export type TokenRefreshResponse = { accessToken: AccessToken, refreshToken: RefreshToken, tokenClaims: TokenClaims, expiresOn: string };
export abstract class AzureAuth implements vscode.Disposable {
protected readonly memdb = new MemoryDatabase();
protected readonly memdb = new MemoryDatabase<string>();
protected readonly WorkSchoolAccountType: string = 'work_school';
protected readonly MicrosoftAccountType: string = 'microsoft';
protected readonly loginEndpointUrl: string;
protected readonly commonTenant: Tenant;
public readonly commonTenant: Tenant;
protected readonly redirectUri: string;
protected readonly scopes: string[];
protected readonly scopesString: string;
@@ -141,185 +76,227 @@ export abstract class AzureAuth implements vscode.Disposable {
this.scopesString = this.scopes.join(' ');
}
public abstract async login(): Promise<AzureAccount | azdata.PromptFailedResult>;
public abstract async autoOAuthCancelled(): Promise<void>;
public abstract async promptForConsent(resourceId: string, tenant: string): Promise<{ tokenRefreshResponse: TokenRefreshResponse, authCompleteDeferred: Deferred<void> } | undefined>;
public dispose() { }
public async refreshAccess(oldAccount: azdata.Account): Promise<azdata.Account> {
const response = await this.getCachedToken(oldAccount.key);
if (!response) {
oldAccount.isStale = true;
return oldAccount;
}
const refreshToken = response.refreshToken;
if (!refreshToken || !refreshToken.key) {
oldAccount.isStale = true;
return oldAccount;
}
public async startLogin(): Promise<AzureAccount | azdata.PromptFailedResult> {
let loginComplete: Deferred<void>;
try {
// Refresh the access token
const tokenResponse = await this.refreshAccessToken(oldAccount.key, refreshToken);
const tenants = await this.getTenants(tokenResponse.accessToken);
// Recreate account object
const newAccount = this.createAccount(tokenResponse.tokenClaims, tokenResponse.accessToken.key, tenants);
const subscriptions = await this.getSubscriptions(newAccount);
newAccount.properties.subscriptions = subscriptions;
return newAccount;
const result = await this.login(this.commonTenant, this.metadata.settings.microsoftResource);
loginComplete = result.authComplete;
if (!result?.response) {
Logger.error('Authentication failed');
return {
canceled: false
};
}
const account = await this.hydrateAccount(result.response.accessToken, result.response.tokenClaims);
loginComplete?.resolve();
return account;
} catch (ex) {
oldAccount.isStale = true;
if (ex.message) {
await vscode.window.showErrorMessage(ex.message);
if (ex instanceof AzureAuthError) {
if (loginComplete) {
loginComplete.reject(ex.getPrintableString());
} else {
vscode.window.showErrorMessage(ex.getPrintableString());
}
}
Logger.error(ex);
return undefined;
}
return oldAccount;
}
private getHomeTenant(account: AzureAccount): Tenant {
// Home is defined by the API
// Lets pick the home tenant - and fall back to commonTenant if they don't exist
return account.properties.tenants.find(t => t.tenantCategory === 'Home') ?? account.properties.tenants[0] ?? this.commonTenant;
}
public async getSecurityToken(account: azdata.Account, azureResource: azdata.AzureResource): Promise<TokenResponse | undefined> {
public async refreshAccess(account: AzureAccount): Promise<AzureAccount> {
try {
const tenant = this.getHomeTenant(account);
const tokenResult = await this.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.MicrosoftResourceManagement);
if (!tokenResult) {
account.isStale = true;
return account;
}
return await this.hydrateAccount(tokenResult, this.getTokenClaims(tokenResult.token));
} catch (ex) {
if (ex instanceof AzureAuthError) {
vscode.window.showErrorMessage(ex.getPrintableString());
}
Logger.error(ex);
account.isStale = true;
return account;
}
}
public async hydrateAccount(token: Token | AccessToken, tokenClaims: TokenClaims): Promise<AzureAccount> {
const tenants = await this.getTenants({ ...token });
const account = this.createAccount(tokenClaims, token.key, tenants);
return account;
}
public async getAccountSecurityToken(account: AzureAccount, tenantId: string, azureResource: azdata.AzureResource): Promise<Token | undefined> {
if (account.isStale === true) {
Logger.log('Account was stale, no tokens being fetched');
Logger.log('Account was stale. No tokens being fetched.');
return undefined;
}
const resource = this.resources.find(s => s.azureResourceId === azureResource);
if (!resource) {
Logger.log('Invalid resource, not fetching', azureResource);
return undefined;
}
const azureAccount = account as AzureAccount;
const response: TokenResponse = {};
const tenant = account.properties.tenants.find(t => t.id === tenantId);
for (const tenant of azureAccount.properties.tenants) {
let cachedTokens = await this.getCachedToken(account.key, resource.id, tenant.id);
// Check expiration
if (cachedTokens) {
const expiresOn = Number(this.memdb.get(this.createMemdbString(account.key.accountId, tenant.id, resource.id)));
const currentTime = new Date().getTime() / 1000;
if (!tenant) {
throw new AzureAuthError(localize('azure.tenantNotFound', "Specifed tenant with ID '{0}' not found.", tenantId), `Tenant ${tenantId} not found.`, undefined);
}
if (!Number.isNaN(expiresOn)) {
const remainingTime = expiresOn - currentTime;
const fiveMinutes = 5 * 60;
// If the remaining time is less than five minutes, assume the token has expired. It's too close to expiration to be meaningful.
if (remainingTime < fiveMinutes) {
cachedTokens = undefined;
}
} else {
// No expiration date, assume expired.
cachedTokens = undefined;
Logger.log('Assuming expired token due to no expiration date - this is expected on first launch.');
}
const cachedTokens = await this.getSavedToken(tenant, resource, account.key);
// Let's check to see if we can just use the cached tokens to return to the user
if (cachedTokens?.accessToken) {
let expiry = Number(cachedTokens.expiresOn);
if (Number.isNaN(expiry)) {
Logger.log('Expiration time was not defined. This is expected on first launch');
expiry = 0;
}
const currentTime = new Date().getTime() / 1000;
// Refresh
if (!cachedTokens) {
let accessToken = cachedTokens.accessToken;
const remainingTime = expiry - currentTime;
const maxTolerance = 2 * 60; // two minutes
const baseToken = await this.getCachedToken(account.key);
if (!baseToken) {
account.isStale = true;
Logger.log('Base token was empty, account is stale.');
return undefined;
}
try {
await this.refreshAccessToken(account.key, baseToken.refreshToken, tenant, resource);
} catch (ex) {
Logger.log(`Could not refresh access token for ${JSON.stringify(tenant)} - silently removing the tenant from the user's account.`);
Logger.error(`Actual error: ${JSON.stringify(ex?.response?.data ?? ex.message ?? ex, undefined, 2)}`);
azureAccount.properties.tenants = azureAccount.properties.tenants.filter(t => t.id !== tenant.id);
continue;
}
cachedTokens = await this.getCachedToken(account.key, resource.id, tenant.id);
if (!cachedTokens) {
Logger.log('Refresh access tokens didn not set cache');
return undefined;
}
if (remainingTime < maxTolerance) {
const result = await this.refreshToken(tenant, resource, cachedTokens.refreshToken);
accessToken = result.accessToken;
}
const { accessToken } = cachedTokens;
response[tenant.id] = {
token: accessToken.token,
key: accessToken.key,
// Let's just return here.
if (accessToken) {
return {
...accessToken,
tokenType: 'Bearer'
};
}
}
// User didn't have any cached tokens, or the cached tokens weren't useful.
// For most users we can use the refresh token from the general microsoft resource to an access token of basically any type of resource we want.
const baseTokens = await this.getSavedToken(this.commonTenant, this.metadata.settings.microsoftResource, account.key);
if (!baseTokens) {
Logger.error('User had no base tokens for the basic resource registered. This should not happen and indicates something went wrong with the authentication cycle');
const msg = localize('azure.noBaseToken', 'Something failed with the authentication, or your tokens have been deleted from the system. Please try adding your account to Azure Data Studio again.');
account.isStale = true;
throw new AzureAuthError(msg, 'No base token found', undefined);
}
// Let's try to convert the access token type, worst case we'll have to prompt the user to do an interactive authentication.
const result = await this.refreshToken(tenant, resource, baseTokens.refreshToken);
if (result.accessToken) {
return {
...result.accessToken,
tokenType: 'Bearer'
};
}
if (azureAccount.properties.subscriptions) {
azureAccount.properties.subscriptions.forEach(subscription => {
// Make sure that tenant has information populated.
if (response[subscription.tenantId]) {
response[subscription.id] = {
...response[subscription.tenantId]
};
}
});
}
return response;
return undefined;
}
public async clearCredentials(account: azdata.AccountKey): Promise<void> {
try {
return this.deleteAccountCache(account);
} catch (ex) {
const msg = localize('azure.cacheErrrorRemove', "Error when removing your account from the cache.");
vscode.window.showErrorMessage(msg);
Logger.error('Error when removing tokens.', ex);
}
}
protected toBase64UrlEncoding(base64string: string): string {
return base64string.replace(/=/g, '').replace(/\+/g, '-').replace(/\//g, '_'); // Need to use base64url encoding
}
protected async makePostRequest(uri: string, postData: { [key: string]: string }, validateStatus = false) {
try {
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/x-www-form-urlencoded'
}
protected abstract async login(tenant: Tenant, resource: Resource): Promise<{ response: OAuthTokenResponse, authComplete: Deferred<void> }>;
/**
* Refreshes a token, if a refreshToken is passed in then we use that. If it is not passed in then we will prompt the user for consent.
* @param tenant
* @param resource
* @param refreshToken
*/
public async refreshToken(tenant: Tenant, resource: Resource, refreshToken: RefreshToken | undefined): Promise<OAuthTokenResponse> {
if (refreshToken) {
const postData: RefreshTokenPostData = {
grant_type: 'refresh_token',
client_id: this.clientId,
refresh_token: refreshToken.token,
tenant: tenant.id,
resource: resource.endpoint
};
if (validateStatus) {
config.validateStatus = () => true;
}
return await axios.post(uri, qs.stringify(postData), config);
} catch (ex) {
Logger.log('Unexpected error making Azure auth request', 'azureCore.postRequest', JSON.stringify(ex?.response?.data, undefined, 2));
throw ex;
return this.getToken(tenant, resource, postData);
}
return this.handleInteractionRequired(tenant, resource);
}
protected async makeGetRequest(token: string, uri: string): Promise<AxiosResponse<any>> {
try {
const config = {
headers: {
Authorization: `Bearer ${token}`,
'Content-Type': 'application/json',
},
public async getToken(tenant: Tenant, resource: Resource, postData: AuthorizationCodePostData | TokenPostData | RefreshTokenPostData): Promise<OAuthTokenResponse> {
const tokenUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/token`;
const response = await this.makePostRequest(tokenUrl, postData);
if (response.data.error === 'interaction_required') {
return this.handleInteractionRequired(tenant, resource);
}
if (response.data.error) {
Logger.error('Response error!', response.data);
throw new AzureAuthError(localize('azure.responseError', "Token retrival failed with an error. Open developer tools to view the error"), 'Token retrival failed', undefined);
}
const accessTokenString = response.data.access_token;
const refreshTokenString = response.data.refresh_token;
const expiresOnString = response.data.expires_on;
return this.getTokenHelper(tenant, resource, accessTokenString, refreshTokenString, expiresOnString);
}
public async getTokenHelper(tenant: Tenant, resource: Resource, accessTokenString: string, refreshTokenString: string, expiresOnString: string): Promise<OAuthTokenResponse> {
if (!accessTokenString) {
const msg = localize('azure.accessTokenEmpty', 'No access token returned from Microsoft OAuth');
throw new AzureAuthError(msg, 'Access token was empty', undefined);
}
const tokenClaims: TokenClaims = this.getTokenClaims(accessTokenString);
const userKey = tokenClaims.sub ?? tokenClaims.oid;
if (!userKey) {
const msg = localize('azure.noUniqueIdentifier', "The user had no unique identifier within AAD");
throw new AzureAuthError(msg, 'No unique identifier', undefined);
}
const accessToken: AccessToken = {
token: accessTokenString,
key: userKey
};
let refreshToken: RefreshToken;
if (refreshTokenString) {
refreshToken = {
token: refreshTokenString,
key: userKey
};
return await axios.get(uri, config);
} catch (ex) {
// Intercept and print error
Logger.log('Unexpected error making Azure auth request', 'azureCore.getRequest', JSON.stringify(ex?.response?.data, undefined, 2));
// rethrow error
throw ex;
}
const result: OAuthTokenResponse = {
accessToken,
refreshToken,
tokenClaims,
expiresOn: expiresOnString
};
const accountKey: azdata.AccountKey = {
providerId: this.metadata.id,
accountId: userKey
};
await this.saveToken(tenant, resource, accountKey, result);
return result;
}
protected async getTenants(token: AccessToken): Promise<Tenant[]> {
//#region tenant calls
public async getTenants(token: AccessToken): Promise<Tenant[]> {
interface TenantResponse { // https://docs.microsoft.com/en-us/rest/api/resources/tenants/list
id: string
tenantId: string
@@ -329,7 +306,7 @@ export abstract class AzureAuth implements vscode.Disposable {
const tenantUri = url.resolve(this.metadata.settings.armResource.endpoint, 'tenants?api-version=2019-11-01');
try {
const tenantResponse = await this.makeGetRequest(token.token, tenantUri);
const tenantResponse = await this.makeGetRequest(tenantUri, token.token);
Logger.pii('getTenants', tenantResponse.data);
const tenants: Tenant[] = tenantResponse.data.value.map((tenantInfo: TenantResponse) => {
return {
@@ -353,97 +330,88 @@ export abstract class AzureAuth implements vscode.Disposable {
}
}
protected async getSubscriptions(account: AzureAccount): Promise<Subscription[]> {
interface SubscriptionResponse { // https://docs.microsoft.com/en-us/rest/api/resources/subscriptions/list
subscriptionId: string
tenantId: string
displayName: string
}
const allSubs: Subscription[] = [];
const tokens = await this.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
if (!tokens) {
Logger.log('There were no resource management tokens to retrieve subscriptions from. Account is stale.');
account.isStale = true;
}
//#endregion
for (const tenant of account.properties.tenants) {
const token = tokens[tenant.id];
const subscriptionUri = url.resolve(this.metadata.settings.armResource.endpoint, 'subscriptions?api-version=2019-11-01');
try {
const subscriptionResponse = await this.makeGetRequest(token.token, subscriptionUri);
Logger.pii('getSubscriptions', subscriptionResponse.data);
const subscriptions: Subscription[] = subscriptionResponse.data.value.map((subscriptionInfo: SubscriptionResponse) => {
return {
id: subscriptionInfo.subscriptionId,
displayName: subscriptionInfo.displayName,
tenantId: subscriptionInfo.tenantId
} as Subscription;
});
allSubs.push(...subscriptions);
} catch (ex) {
Logger.error(ex);
throw new Error('Error retrieving subscription information');
}
//#region token management
private async saveToken(tenant: Tenant, resource: Resource, accountKey: azdata.AccountKey, { accessToken, refreshToken, expiresOn }: OAuthTokenResponse) {
const msg = localize('azure.cacheErrorAdd', "Error when adding your account to the cache.");
if (!tenant.id || !resource.id) {
Logger.pii('Tenant ID or resource ID was undefined', tenant, resource);
throw new AzureAuthError(msg, 'Adding account to cache failed', undefined);
}
return allSubs;
}
protected async getToken(postData: { [key: string]: string }, tenant = this.commonTenant, resourceId: string = '', resourceEndpoint: string = ''): Promise<TokenRefreshResponse | undefined> {
try {
let refreshResponse: TokenRefreshResponse;
try {
const tokenUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/token`;
const tokenResponse = await this.makePostRequest(tokenUrl, postData);
Logger.pii(JSON.stringify(tokenResponse.data));
const tokenClaims = this.getTokenClaims(tokenResponse.data.access_token);
const accessToken: AccessToken = {
token: tokenResponse.data.access_token,
key: tokenClaims.oid ?? tokenClaims.email ?? tokenClaims.unique_name ?? tokenClaims.name,
};
const refreshToken: RefreshToken = {
token: tokenResponse.data.refresh_token,
key: accessToken.key
};
const expiresOn = tokenResponse.data.expires_on;
refreshResponse = { accessToken, refreshToken, tokenClaims, expiresOn };
} catch (ex) {
Logger.pii(JSON.stringify(ex?.response?.data));
if (ex?.response?.data?.error === 'interaction_required') {
const shouldOpenLink = await this.openConsentDialog(tenant, resourceId);
if (shouldOpenLink === true) {
const { tokenRefreshResponse, authCompleteDeferred } = await this.promptForConsent(resourceEndpoint, tenant.id);
refreshResponse = tokenRefreshResponse;
authCompleteDeferred.resolve();
} else {
vscode.window.showInformationMessage(localize('azure.noConsentToReauth', "The authentication failed since Azure Data Studio was unable to open re-authentication page."));
}
} else {
return undefined;
}
}
this.memdb.set(this.createMemdbString(refreshResponse.accessToken.key, tenant.id, resourceId), refreshResponse.expiresOn);
return refreshResponse;
} catch (err) {
const msg = localize('azure.noToken', "Retrieving the Azure token failed. Please sign in again.");
vscode.window.showErrorMessage(msg);
throw new Error(err);
await this.tokenCache.saveCredential(`${accountKey.accountId}_access_${resource.id}_${tenant.id}`, JSON.stringify(accessToken));
await this.tokenCache.saveCredential(`${accountKey.accountId}_refresh_${resource.id}_${tenant.id}`, JSON.stringify(refreshToken));
this.memdb.set(`${accountKey.accountId}_${tenant.id}_${resource.id}`, expiresOn);
} catch (ex) {
Logger.error(ex);
throw new AzureAuthError(msg, 'Adding account to cache failed', ex);
}
}
public async openConsentDialog(tenant: Tenant, resourceId: string): Promise<boolean> {
public async getSavedToken(tenant: Tenant, resource: Resource, accountKey: azdata.AccountKey): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> {
const getMsg = localize('azure.cacheErrorGet', "Error when getting your account from the cache");
const parseMsg = localize('azure.cacheErrorParse', "Error when parsing your account from the cache");
if (!tenant.id || !resource.id) {
Logger.pii('Tenant ID or resource ID was undefined', tenant, resource);
throw new AzureAuthError(getMsg, 'Getting account from cache failed', undefined);
}
let accessTokenString: string;
let refreshTokenString: string;
let expiresOn: string;
try {
accessTokenString = await this.tokenCache.getCredential(`${accountKey.accountId}_access_${resource.id}_${tenant.id}`);
refreshTokenString = await this.tokenCache.getCredential(`${accountKey.accountId}_refresh_${resource.id}_${tenant.id}`);
expiresOn = this.memdb.get(`${accountKey.accountId}_${tenant.id}_${resource.id}`);
} catch (ex) {
Logger.error(ex);
throw new AzureAuthError(getMsg, 'Getting account from cache failed', ex);
}
try {
if (!accessTokenString) {
return undefined;
}
const accessToken: AccessToken = JSON.parse(accessTokenString);
let refreshToken: RefreshToken;
if (refreshTokenString) {
refreshToken = JSON.parse(refreshTokenString);
}
return {
accessToken, refreshToken, expiresOn
};
} catch (ex) {
Logger.error(ex);
throw new AzureAuthError(parseMsg, 'Parsing account from cache failed', ex);
}
}
//#endregion
//#region interaction handling
public async handleInteractionRequired(tenant: Tenant, resource: Resource): Promise<OAuthTokenResponse | undefined> {
const shouldOpen = await this.askUserForInteraction(tenant, resource);
if (shouldOpen) {
const result = await this.login(tenant, resource);
result?.authComplete?.resolve();
return result?.response;
}
return undefined;
}
/**
* Asks the user if they would like to do the interaction based authentication as required by OAuth2
* @param tenant
* @param resource
*/
private async askUserForInteraction(tenant: Tenant, resource: Resource): Promise<boolean> {
if (!tenant.displayName && !tenant.id) {
throw new Error('Tenant did not have display name or id');
}
if (tenant.id === 'common') {
throw new Error('Common tenant should not need consent');
}
const getTenantConfigurationSet = (): Set<string> => {
const configuration = vscode.workspace.getConfiguration('azure.tenant.config');
let values: string[] = configuration.get('filter') ?? [];
@@ -486,7 +454,7 @@ export abstract class AzureAuth implements vscode.Disposable {
}
};
const messageBody = localize('azurecore.consentDialog.body', "Your tenant '{0} ({1})' requires you to re-authenticate again to access {2} resources. Press Open to start the authentication process.", tenant.displayName, tenant.id, resourceId);
const messageBody = localize('azurecore.consentDialog.body', "Your tenant '{0} ({1})' requires you to re-authenticate again to access {2} resources. Press Open to start the authentication process.", tenant.displayName, tenant.id, resource.id);
const result = await vscode.window.showInformationMessage(messageBody, { modal: true }, openItem, closeItem, dontAskAgainItem);
if (result.action) {
@@ -495,113 +463,9 @@ export abstract class AzureAuth implements vscode.Disposable {
return result.booleanResult;
}
//#endregion
protected getTokenClaims(accessToken: string): TokenClaims | undefined {
try {
const split = accessToken.split('.');
return JSON.parse(Buffer.from(split[1], 'base64').toString('binary'));
} catch (ex) {
throw new Error('Unable to read token claims: ' + JSON.stringify(ex));
}
}
private async refreshAccessToken(account: azdata.AccountKey, rt: RefreshToken, tenant: Tenant = this.commonTenant, resource?: Resource): Promise<TokenRefreshResponse> {
const postData: { [key: string]: string } = {
grant_type: 'refresh_token',
refresh_token: rt.token,
client_id: this.clientId,
tenant: tenant.id,
};
if (resource) {
postData.resource = resource.endpoint;
}
const getTokenResponse = await this.getToken(postData, tenant, resource?.id, resource?.endpoint);
const accessToken = getTokenResponse?.accessToken;
const refreshToken = getTokenResponse?.refreshToken;
if (!accessToken || !refreshToken) {
Logger.log('Access or refresh token were undefined');
const msg = localize('azure.refreshTokenError', "Error when refreshing your account.");
throw new Error(msg);
}
await this.setCachedToken(account, accessToken, refreshToken, resource?.id, tenant?.id);
return getTokenResponse;
}
public async setCachedToken(account: azdata.AccountKey, accessToken: AccessToken, refreshToken: RefreshToken, resourceId?: string, tenantId?: string): Promise<void> {
const msg = localize('azure.cacheErrorAdd', "Error when adding your account to the cache.");
resourceId = resourceId ?? '';
tenantId = tenantId ?? '';
if (!accessToken || !accessToken.token || !refreshToken.token || !accessToken.key) {
throw new Error(msg);
}
try {
await this.tokenCache.saveCredential(`${account.accountId}_access_${resourceId}_${tenantId}`, JSON.stringify(accessToken));
await this.tokenCache.saveCredential(`${account.accountId}_refresh_${resourceId}_${tenantId}`, JSON.stringify(refreshToken));
} catch (ex) {
Logger.error('Error when storing tokens.', ex);
throw new Error(msg);
}
}
public async getCachedToken(account: azdata.AccountKey, resourceId?: string, tenantId?: string): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken } | undefined> {
resourceId = resourceId ?? '';
tenantId = tenantId ?? '';
let accessToken: AccessToken;
let refreshToken: RefreshToken;
try {
accessToken = JSON.parse(await this.tokenCache.getCredential(`${account.accountId}_access_${resourceId}_${tenantId}`));
refreshToken = JSON.parse(await this.tokenCache.getCredential(`${account.accountId}_refresh_${resourceId}_${tenantId}`));
} catch (ex) {
return undefined;
}
if (!accessToken || !refreshToken) {
return undefined;
}
if (!refreshToken.token || !refreshToken.key) {
return undefined;
}
if (!accessToken.token || !accessToken.key) {
return undefined;
}
return {
accessToken,
refreshToken
};
}
public createMemdbString(accountKey: string, tenantId: string, resourceId: string): string {
return `${accountKey}_${tenantId}_${resourceId}`;
}
public async deleteAccountCache(account: azdata.AccountKey): Promise<void> {
const results = await this.tokenCache.findCredentials(account.accountId);
for (let { account } of results) {
await this.tokenCache.clearCredential(account);
}
}
public async deleteAllCache(): Promise<void> {
const results = await this.tokenCache.findCredentials('');
for (let { account } of results) {
await this.tokenCache.clearCredential(account);
}
}
//#region data modeling
public createAccount(tokenClaims: TokenClaims, key: string, tenants: Tenant[]): AzureAccount {
// Determine if this is a microsoft account
@@ -663,4 +527,184 @@ export abstract class AzureAuth implements vscode.Disposable {
return account;
}
//#endregion
//#region network functions
public async makePostRequest(url: string, postData: AuthorizationCodePostData | TokenPostData | DeviceCodeStartPostData | DeviceCodeCheckPostData): Promise<AxiosResponse<any>> {
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/x-www-form-urlencoded'
},
validateStatus: () => true // Never throw
};
// Intercept response and print out the response for future debugging
const response = await axios.post(url, qs.stringify(postData), config);
Logger.pii(url, postData, response.data);
return response;
}
private async makeGetRequest(url: string, token: string): Promise<AxiosResponse<any>> {
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${token}`
},
validateStatus: () => true // Never throw
};
const response = await axios.get(url, config);
Logger.pii(url, response.data);
return response;
}
//#endregion
//#region inconsequential
protected getTokenClaims(accessToken: string): TokenClaims | undefined {
try {
const split = accessToken.split('.');
return JSON.parse(Buffer.from(split[1], 'base64').toString('binary'));
} catch (ex) {
throw new Error('Unable to read token claims: ' + JSON.stringify(ex));
}
}
protected toBase64UrlEncoding(base64string: string): string {
return base64string.replace(/=/g, '').replace(/\+/g, '-').replace(/\//g, '_'); // Need to use base64url encoding
}
public async deleteAllCache(): Promise<void> {
const results = await this.tokenCache.findCredentials('');
for (let { account } of results) {
await this.tokenCache.clearCredential(account);
}
}
public async clearCredentials(account: azdata.AccountKey): Promise<void> {
try {
return this.deleteAccountCache(account);
} catch (ex) {
const msg = localize('azure.cacheErrrorRemove', "Error when removing your account from the cache.");
vscode.window.showErrorMessage(msg);
Logger.error('Error when removing tokens.', ex);
}
}
public async deleteAccountCache(account: azdata.AccountKey): Promise<void> {
const results = await this.tokenCache.findCredentials(account.accountId);
for (let { account } of results) {
await this.tokenCache.clearCredential(account);
}
}
public async dispose() { }
public async autoOAuthCancelled(): Promise<void> { }
//#endregion
}
//#region models
export interface AccountKey {
/**
* Account Key - uniquely identifies an account
*/
key: string
}
export interface AccessToken extends AccountKey {
/**
* Access Token
*/
token: string;
}
export interface RefreshToken extends AccountKey {
/**
* Refresh Token
*/
token: string;
/**
* Account Key
*/
key: string
}
export interface MultiTenantTokenResponse {
[tenantId: string]: Token
}
export interface Token extends AccountKey {
/**
* Access token
*/
token: string;
/**
* TokenType
*/
tokenType: string;
}
export interface TokenClaims { // https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens
aud: string;
iss: string;
iat: number;
idp: string,
nbf: number;
exp: number;
c_hash: string;
at_hash: string;
aio: string;
preferred_username: string;
email: string;
name: string;
nonce: string;
oid: string;
roles: string[];
rh: string;
sub: string;
tid: string;
unique_name: string;
uti: string;
ver: string;
}
export type OAuthTokenResponse = { accessToken: AccessToken, refreshToken: RefreshToken, tokenClaims: TokenClaims, expiresOn: string };
export interface TokenPostData {
grant_type: 'refresh_token' | 'authorization_code' | 'urn:ietf:params:oauth:grant-type:device_code';
client_id: string;
resource: string;
}
export interface RefreshTokenPostData extends TokenPostData {
grant_type: 'refresh_token';
refresh_token: string;
client_id: string;
tenant: string
}
export interface AuthorizationCodePostData extends TokenPostData {
grant_type: 'authorization_code';
code: string;
code_verifier: string;
redirect_uri: string;
}
export interface DeviceCodeStartPostData extends Omit<TokenPostData, 'grant_type'> {
}
export interface DeviceCodeCheckPostData extends Omit<TokenPostData, 'resource'> {
grant_type: 'urn:ietf:params:oauth:grant-type:device_code',
tenant: string,
code: string
}
//#endregion

View File

@@ -3,49 +3,37 @@
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { AuthorizationCodePostData, AzureAuth, OAuthTokenResponse } from './azureAuth';
import { AzureAccountProviderMetadata, AzureAuthType, Deferred, Resource, Tenant } from '../interfaces';
import * as vscode from 'vscode';
import * as crypto from 'crypto';
import { SimpleTokenCache } from '../simpleTokenCache';
import { SimpleWebServer } from '../utils/simpleWebServer';
import { AzureAuthError } from './azureAuthError';
import { Logger } from '../../utils/Logger';
import * as nls from 'vscode-nls';
import { promises as fs } from 'fs';
import * as path from 'path';
import * as http from 'http';
import * as qs from 'qs';
import { promises as fs } from 'fs';
import {
AzureAuth,
AccessToken,
RefreshToken,
TokenClaims,
TokenRefreshResponse,
} from './azureAuth';
import {
AzureAccountProviderMetadata,
AzureAuthType,
Deferred
} from '../interfaces';
import { SimpleWebServer } from '../utils/simpleWebServer';
import { SimpleTokenCache } from '../simpleTokenCache';
import { Logger } from '../../utils/Logger';
const localize = nls.loadMessageBundle();
function parseQuery(uri: vscode.Uri) {
return uri.query.split('&').reduce((prev: any, current) => {
const queryString = current.split('=');
prev[queryString[0]] = queryString[1];
return prev;
}, {});
interface AuthCodeResponse {
authCode: string;
codeVerifier: string;
redirectUri: string;
}
interface AuthCodeResponse {
authCode: string,
codeVerifier: string
interface CryptoValues {
nonce: string;
codeVerifier: string;
codeChallenge: string;
}
export class AzureAuthCodeGrant extends AzureAuth {
private static readonly USER_FRIENDLY_NAME: string = localize('azure.azureAuthCodeGrantName', "Azure Auth Code Grant");
private static readonly USER_FRIENDLY_NAME: string = localize('azure.azureAuthCodeGrantName', 'Azure Auth Code Grant');
private server: SimpleWebServer;
constructor(
@@ -57,105 +45,52 @@ export class AzureAuthCodeGrant extends AzureAuth {
super(metadata, tokenCache, context, uriEventEmitter, AzureAuthType.AuthCodeGrant, AzureAuthCodeGrant.USER_FRIENDLY_NAME);
}
public async promptForConsent(resourceEndpoint: string, tenant: string = this.commonTenant.id): Promise<{ tokenRefreshResponse: TokenRefreshResponse, authCompleteDeferred: Deferred<void> } | undefined> {
protected async login(tenant: Tenant, resource: Resource): Promise<{ response: OAuthTokenResponse, authComplete: Deferred<void> }> {
let authCompleteDeferred: Deferred<void>;
let authCompletePromise = new Promise<void>((resolve, reject) => authCompleteDeferred = { resolve, reject });
let authResponse: AuthCodeResponse;
if (vscode.env.uiKind === vscode.UIKind.Web) {
authResponse = await this.loginWithoutLocalServer(resourceEndpoint, tenant);
authResponse = await this.loginWeb(tenant, resource);
} else {
authResponse = await this.loginWithLocalServer(authCompletePromise, resourceEndpoint, tenant);
}
let tokenClaims: TokenClaims;
let accessToken: AccessToken;
let refreshToken: RefreshToken;
let expiresOn: string;
try {
const { accessToken: at, refreshToken: rt, tokenClaims: tc, expiresOn: eo } = await this.getTokenWithAuthCode(authResponse.authCode, authResponse.codeVerifier, this.redirectUri);
tokenClaims = tc;
accessToken = at;
refreshToken = rt;
expiresOn = eo;
} catch (ex) {
if (ex.msg) {
vscode.window.showErrorMessage(ex.msg);
}
Logger.error(ex);
}
if (!accessToken) {
const msg = localize('azure.tokenFail', "Failure when retrieving tokens.");
authCompleteDeferred.reject(new Error(msg));
throw Error('Failure when retrieving tokens');
authResponse = await this.loginDesktop(tenant, resource, authCompletePromise);
}
return {
tokenRefreshResponse: { accessToken, refreshToken, tokenClaims, expiresOn },
authCompleteDeferred
response: await this.getTokenWithAuthorizationCode(tenant, resource, authResponse),
authComplete: authCompleteDeferred
};
}
public async autoOAuthCancelled(): Promise<void> {
return this.server.shutdown();
}
public async loginWithLocalServer(authCompletePromise: Promise<void>, resourceId: string, tenant: string = this.commonTenant.id): Promise<AuthCodeResponse | undefined> {
this.server = new SimpleWebServer();
const nonce = crypto.randomBytes(16).toString('base64');
let serverPort: string;
try {
serverPort = await this.server.startup();
} catch (err) {
const msg = localize('azure.serverCouldNotStart', 'Server could not start. This could be a permissions error or an incompatibility on your system. You can try enabling device code authentication from settings.');
vscode.window.showErrorMessage(msg);
Logger.error(JSON.stringify(err));
return undefined;
}
// The login code to use
let loginUrl: string;
let codeVerifier: string;
{
codeVerifier = this.toBase64UrlEncoding(crypto.randomBytes(32).toString('base64'));
const state = `${serverPort},${encodeURIComponent(nonce)}`;
const codeChallenge = this.toBase64UrlEncoding(crypto.createHash('sha256').update(codeVerifier).digest('base64'));
const loginQuery = {
response_type: 'code',
response_mode: 'query',
client_id: this.clientId,
redirect_uri: this.redirectUri,
state,
prompt: 'select_account',
code_challenge_method: 'S256',
code_challenge: codeChallenge,
resource: resourceId
};
loginUrl = `${this.loginEndpointUrl}${tenant}/oauth2/authorize?${qs.stringify(loginQuery)}`;
}
await vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${serverPort}/signin?nonce=${encodeURIComponent(nonce)}`));
const authCode = await this.addServerListeners(this.server, nonce, loginUrl, authCompletePromise);
return {
authCode,
codeVerifier
/**
* Requests an OAuthTokenResponse from Microsoft OAuth
*
* @param tenant
* @param resource
* @param authCode
* @param redirectUri
* @param codeVerifier
*/
private async getTokenWithAuthorizationCode(tenant: Tenant, resource: Resource, { authCode, redirectUri, codeVerifier }: AuthCodeResponse): Promise<OAuthTokenResponse | undefined> {
const postData: AuthorizationCodePostData = {
grant_type: 'authorization_code',
code: authCode,
client_id: this.clientId,
code_verifier: codeVerifier,
redirect_uri: redirectUri,
resource: resource.endpoint
};
return this.getToken(tenant, resource, postData);
}
public async loginWithoutLocalServer(resourceId: string, tenant: string = this.commonTenant.id): Promise<AuthCodeResponse | undefined> {
private async loginWeb(tenant: Tenant, resource: Resource): Promise<AuthCodeResponse> {
const callbackUri = await vscode.env.asExternalUri(vscode.Uri.parse(`${vscode.env.uriScheme}://microsoft.azurecore`));
const nonce = crypto.randomBytes(16).toString('base64');
const { nonce, codeVerifier, codeChallenge } = this.createCryptoValues();
const port = (callbackUri.authority.match(/:([0-9]*)$/) || [])[1] || (callbackUri.scheme === 'https' ? 443 : 80);
const state = `${port},${encodeURIComponent(nonce)},${encodeURIComponent(callbackUri.query)}`;
const codeVerifier = this.toBase64UrlEncoding(crypto.randomBytes(32).toString('base64'));
const codeChallenge = this.toBase64UrlEncoding(crypto.createHash('sha256').update(codeVerifier).digest('base64'));
const loginQuery = {
response_type: 'code',
response_mode: 'query',
@@ -165,26 +100,27 @@ export class AzureAuthCodeGrant extends AzureAuth {
prompt: 'select_account',
code_challenge_method: 'S256',
code_challenge: codeChallenge,
resource: resourceId
resource: resource.id
};
const signInUrl = `${this.loginEndpointUrl}${tenant}/oauth2/authorize?${qs.stringify(loginQuery)}`;
await vscode.env.openExternal(vscode.Uri.parse(signInUrl));
const authCode = await this.handleCodeResponse(state);
const authCode = await this.handleWebResponse(state);
return {
authCode,
codeVerifier
codeVerifier,
redirectUri: this.redirectUri
};
}
public async handleCodeResponse(state: string): Promise<string> {
private async handleWebResponse(state: string): Promise<string> {
let uriEventListener: vscode.Disposable;
return new Promise((resolve: (value: any) => void, reject) => {
uriEventListener = this.uriEventEmitter.event(async (uri: vscode.Uri) => {
try {
const query = parseQuery(uri);
const query = this.parseQuery(uri);
const code = query.code;
if (query.state !== state && decodeURIComponent(query.state) !== state) {
reject(new Error('State mismatch'));
@@ -200,33 +136,46 @@ export class AzureAuthCodeGrant extends AzureAuth {
});
}
public async login(): Promise<azdata.Account | azdata.PromptFailedResult> {
const { tokenRefreshResponse, authCompleteDeferred } = await this.promptForConsent(this.metadata.settings.signInResourceId);
const { accessToken, refreshToken, tokenClaims } = tokenRefreshResponse;
private parseQuery(uri: vscode.Uri): { [key: string]: string } {
return uri.query.split('&').reduce((prev: any, current) => {
const queryString = current.split('=');
prev[queryString[0]] = queryString[1];
return prev;
}, {});
}
const tenants = await this.getTenants(accessToken);
private async loginDesktop(tenant: Tenant, resource: Resource, authCompletePromise: Promise<void>): Promise<AuthCodeResponse> {
this.server = new SimpleWebServer();
let serverPort: string;
try {
await this.setCachedToken({ accountId: accessToken.key, providerId: this.metadata.id }, accessToken, refreshToken);
serverPort = await this.server.startup();
} catch (ex) {
Logger.error(ex);
if (ex.msg) {
vscode.window.showErrorMessage(ex.msg);
authCompleteDeferred.reject(ex);
} else {
authCompleteDeferred.reject(new Error('There was an issue when storing the cache.'));
}
return { canceled: false } as azdata.PromptFailedResult;
const msg = localize('azure.serverCouldNotStart', 'Server could not start. This could be a permissions error or an incompatibility on your system. You can try enabling device code authentication from settings.');
throw new AzureAuthError(msg, 'Server could not start', ex);
}
const { nonce, codeVerifier, codeChallenge } = this.createCryptoValues();
const state = `${serverPort},${encodeURIComponent(nonce)}`;
const loginQuery = {
response_type: 'code',
response_mode: 'query',
client_id: this.clientId,
redirect_uri: this.redirectUri,
state,
prompt: 'select_account',
code_challenge_method: 'S256',
code_challenge: codeChallenge,
resource: resource.endpoint
};
const loginUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/authorize?${qs.stringify(loginQuery)}`;
await vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${serverPort}/signin?nonce=${encodeURIComponent(nonce)}`));
const authCode = await this.addServerListeners(this.server, nonce, loginUrl, authCompletePromise);
return {
authCode,
codeVerifier,
redirectUri: this.redirectUri
};
const account = this.createAccount(tokenClaims, accessToken.key, tenants);
const subscriptions = await this.getSubscriptions(account);
account.properties.subscriptions = subscriptions;
authCompleteDeferred.resolve();
return account;
}
private async addServerListeners(server: SimpleWebServer, nonce: string, loginUrl: string, authComplete: Promise<void>): Promise<string> {
@@ -266,7 +215,7 @@ export class AzureAuthCodeGrant extends AzureAuth {
if (receivedNonce !== nonce) {
res.writeHead(400, { 'content-type': 'text/html' });
res.write(localize('azureAuth.nonceError', "Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again."));
res.write(localize('azureAuth.nonceError', 'Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again.'));
res.end();
Logger.error('nonce no match', receivedNonce, nonce);
return;
@@ -283,7 +232,7 @@ export class AzureAuthCodeGrant extends AzureAuth {
const stateSplit = state.split(',');
if (stateSplit.length !== 2) {
res.writeHead(400, { 'content-type': 'text/html' });
res.write(localize('azureAuth.stateError', "Authentication failed due to a state mismatch, please close ADS and try again."));
res.write(localize('azureAuth.stateError', 'Authentication failed due to a state mismatch, please close ADS and try again.'));
res.end();
reject(new Error('State mismatch'));
return;
@@ -291,7 +240,7 @@ export class AzureAuthCodeGrant extends AzureAuth {
if (stateSplit[1] !== encodeURIComponent(nonce)) {
res.writeHead(400, { 'content-type': 'text/html' });
res.write(localize('azureAuth.nonceError', "Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again."));
res.write(localize('azureAuth.nonceError', 'Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again.'));
res.end();
reject(new Error('Nonce mismatch'));
return;
@@ -310,20 +259,14 @@ export class AzureAuthCodeGrant extends AzureAuth {
});
}
private async getTokenWithAuthCode(authCode: string, codeVerifier: string, redirectUri: string): Promise<TokenRefreshResponse | undefined> {
const postData = {
grant_type: 'authorization_code',
code: authCode,
client_id: this.clientId,
code_verifier: codeVerifier,
redirect_uri: redirectUri,
resource: this.metadata.settings.signInResourceId
private createCryptoValues(): CryptoValues {
const nonce = crypto.randomBytes(16).toString('base64');
const codeVerifier = this.toBase64UrlEncoding(crypto.randomBytes(32).toString('base64'));
const codeChallenge = this.toBase64UrlEncoding(crypto.createHash('sha256').update(codeVerifier).digest('base64'));
return {
nonce, codeVerifier, codeChallenge
};
return this.getToken(postData);
}
public dispose() {
this.server?.shutdown().catch(console.error);
}
}

View File

@@ -0,0 +1,24 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
export class AzureAuthError extends Error {
private readonly _originalMessage: string;
constructor(localizedMessage: string, _originalMessage: string, private readonly originalException: any) {
super(localizedMessage);
}
get originalMessage(): string {
return this._originalMessage;
}
getPrintableString(): string {
return JSON.stringify({
originalMessage: this.originalMessage,
originalException: this.originalException
}, undefined, 2);
}
}

View File

@@ -8,21 +8,24 @@ import * as vscode from 'vscode';
import * as nls from 'vscode-nls';
import {
AzureAuth,
TokenClaims,
AccessToken,
RefreshToken,
OAuthTokenResponse,
DeviceCodeStartPostData,
DeviceCodeCheckPostData,
} from './azureAuth';
import {
AzureAccountProviderMetadata,
AzureAccount,
AzureAuthType,
Tenant,
Resource,
Deferred,
// Tenant,
// Subscription
} from '../interfaces';
import { SimpleTokenCache } from '../simpleTokenCache';
import { Logger } from '../../utils/Logger';
const localize = nls.loadMessageBundle();
interface DeviceCodeLogin { // https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code
@@ -43,7 +46,6 @@ interface DeviceCodeLoginResult {
}
export class AzureDeviceCode extends AzureAuth {
private static readonly USER_FRIENDLY_NAME: string = localize('azure.azureDeviceCodeAuth', "Azure Device Code");
private readonly pageTitle: string;
constructor(
@@ -56,60 +58,42 @@ export class AzureDeviceCode extends AzureAuth {
this.pageTitle = localize('addAccount', "Add {0} account", this.metadata.displayName);
}
protected async login(tenant: Tenant, resource: Resource): Promise<{ response: OAuthTokenResponse, authComplete: Deferred<void> }> {
let authCompleteDeferred: Deferred<void>;
let authCompletePromise = new Promise<void>((resolve, reject) => authCompleteDeferred = { resolve, reject });
public async promptForConsent(resourceId: string, tenant: string = this.commonTenant.id): Promise<undefined> {
vscode.window.showErrorMessage(localize('azure.deviceCodeDoesNotSupportConsent', "Device code authentication does not support prompting for consent. Switch the authentication method in settings to code grant."));
return undefined;
const uri = `${this.loginEndpointUrl}/${this.commonTenant.id}/oauth2/devicecode`;
const postData: DeviceCodeStartPostData = {
client_id: this.clientId,
resource: resource.endpoint
};
const postResult = await this.makePostRequest(uri, postData);
const initialDeviceLogin: DeviceCodeLogin = postResult.data;
await azdata.accounts.beginAutoOAuthDeviceCode(this.metadata.id, this.pageTitle, initialDeviceLogin.message, initialDeviceLogin.user_code, initialDeviceLogin.verification_url);
const finalDeviceLogin = await this.setupPolling(initialDeviceLogin);
const accessTokenString = finalDeviceLogin.access_token;
const refreshTokenString = finalDeviceLogin.refresh_token;
const currentTime = new Date().getTime() / 1000;
const expiresOn = `${currentTime + finalDeviceLogin.expires_in}`;
const result = await this.getTokenHelper(tenant, resource, accessTokenString, refreshTokenString, expiresOn);
this.closeOnceComplete(authCompletePromise).catch(Logger.error);
return {
response: result,
authComplete: authCompleteDeferred
};
}
public async login(): Promise<AzureAccount | azdata.PromptFailedResult> {
try {
const uri = `${this.loginEndpointUrl}/${this.commonTenant}/oauth2/devicecode`;
const postResult = await this.makePostRequest(uri, {
client_id: this.clientId,
resource: this.metadata.settings.signInResourceId
});
const initialDeviceLogin: DeviceCodeLogin = postResult.data;
await azdata.accounts.beginAutoOAuthDeviceCode(this.metadata.id, this.pageTitle, initialDeviceLogin.message, initialDeviceLogin.user_code, initialDeviceLogin.verification_url);
const finalDeviceLogin = await this.setupPolling(initialDeviceLogin);
let tokenClaims: TokenClaims;
let accessToken: AccessToken;
let refreshToken: RefreshToken;
// let tenants: Tenant[];
// let subscriptions: Subscription[];
tokenClaims = this.getTokenClaims(finalDeviceLogin.access_token);
accessToken = {
token: finalDeviceLogin.access_token,
key: tokenClaims.email || tokenClaims.unique_name || tokenClaims.name,
};
refreshToken = {
token: finalDeviceLogin.refresh_token,
key: accessToken.key,
};
await this.setCachedToken({ accountId: accessToken.key, providerId: this.metadata.id }, accessToken, refreshToken);
const tenants = await this.getTenants(accessToken);
const account = this.createAccount(tokenClaims, accessToken.key, tenants);
const subscriptions = await this.getSubscriptions(account);
account.properties.subscriptions = subscriptions;
return account;
} catch (ex) {
console.log(ex);
if (ex.msg) {
vscode.window.showErrorMessage(ex.msg);
}
return { canceled: false };
} finally {
azdata.accounts.endAutoOAuthDeviceCode();
}
private async closeOnceComplete(promise: Promise<void>): Promise<void> {
await promise;
azdata.accounts.endAutoOAuthDeviceCode();
}
@@ -141,14 +125,14 @@ export class AzureDeviceCode extends AzureAuth {
const msg = localize('azure.deviceCodeCheckFail', "Error encountered when trying to check for login results");
try {
const uri = `${this.loginEndpointUrl}/${this.commonTenant}/oauth2/token`;
const postData = {
const postData: DeviceCodeCheckPostData = {
grant_type: 'urn:ietf:params:oauth:grant-type:device_code',
client_id: this.clientId,
tenant: this.commonTenant.id,
code: info.device_code
};
const postResult = await this.makePostRequest(uri, postData, true);
const postResult = await this.makePostRequest(uri, postData);
const result: DeviceCodeLoginResult = postResult.data;
@@ -164,5 +148,4 @@ export class AzureDeviceCode extends AzureAuth {
public async autoOAuthCancelled(): Promise<void> {
return azdata.accounts.endAutoOAuthDeviceCode();
}
}

View File

@@ -10,14 +10,15 @@ import * as nls from 'vscode-nls';
import {
AzureAccountProviderMetadata,
AzureAuthType,
Deferred
Deferred,
AzureAccount
} from './interfaces';
import { SimpleTokenCache } from './simpleTokenCache';
import { AzureAuth, TokenResponse } from './auths/azureAuth';
import { Logger } from '../utils/Logger';
import { MultiTenantTokenResponse, Token, AzureAuth } from './auths/azureAuth';
import { AzureAuthCodeGrant } from './auths/azureAuthCodeGrant';
import { AzureDeviceCode } from './auths/azureDeviceCode';
import { Logger } from '../utils/Logger';
const localize = nls.loadMessageBundle();
@@ -101,14 +102,29 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp
}
getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<TokenResponse | undefined> {
getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<MultiTenantTokenResponse | undefined> {
return this._getSecurityToken(account, resource);
}
private async _getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Promise<TokenResponse | undefined> {
getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<Token | undefined> {
return this._getAccountSecurityToken(account, tenant, resource);
}
private async _getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Promise<Token | undefined> {
await this.initCompletePromise;
const azureAuth = this.getAuthMethod(undefined);
return azureAuth?.getSecurityToken(account, resource);
return azureAuth?.getAccountSecurityToken(account, tenant, resource);
}
private async _getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Promise<MultiTenantTokenResponse | undefined> {
vscode.window.showInformationMessage(localize('azure.deprecatedGetSecurityToken', "A call was made to azdata.accounts.getSecurityToken, this method is deprecated and will be removed in future releases. Please use getAccountSecurityToken instead."));
const azureAccount = account as AzureAccount;
const response: MultiTenantTokenResponse = {};
for (const tenant of azureAccount.properties.tenants) {
response[tenant.id] = await this._getAccountSecurityToken(account, tenant.id, resource);
}
return response;
}
prompt(): Thenable<azdata.Account | azdata.PromptFailedResult> {
@@ -134,7 +150,7 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp
}
if (this.authMappings.size === 1) {
return this.getAuthMethod(undefined).login();
return this.getAuthMethod(undefined).startLogin();
}
const options: Option[] = [];
@@ -150,7 +166,7 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp
return { canceled: true };
}
return pick.azureAuth.login();
return pick.azureAuth.startLogin();
}
refresh(account: azdata.Account): Thenable<azdata.Account | azdata.PromptFailedResult> {

View File

@@ -64,11 +64,6 @@ interface Settings {
*/
clientId?: string;
/**
* Identifier of the resource to request when signing in
*/
signInResourceId?: string;
/**
* Information that describes the Microsoft resource management resource
*/
@@ -177,10 +172,6 @@ interface AzureAccountProperties {
*/
tenants: Tenant[];
/**
* A list of subscriptions the user belongs to
*/
subscriptions?: Subscription[];
}
export interface Subscription {

View File

@@ -17,7 +17,6 @@ const publicAzureSettings: ProviderSettings = {
settings: {
host: 'https://login.microsoftonline.com/',
clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255',
signInResourceId: 'https://management.core.windows.net/',
microsoftResource: {
id: 'marm',
endpoint: 'https://management.core.windows.net/',
@@ -72,7 +71,6 @@ const usGovAzureSettings: ProviderSettings = {
settings: {
host: 'https://login.microsoftonline.us/',
clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255',
signInResourceId: 'https://management.core.usgovcloudapi.net/',
microsoftResource: {
id: 'marm',
endpoint: 'https://management.core.usgovcloudapi.net/',
@@ -121,7 +119,6 @@ const usNatAzureSettings: ProviderSettings = {
settings: {
host: 'https://login.microsoftonline.eaglex.ic.gov/',
clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255',
signInResourceId: 'https://management.core.eaglex.ic.gov/',
microsoftResource: {
id: 'marm',
endpoint: 'https://management.azure.eaglex.ic.gov/',
@@ -171,7 +168,6 @@ const germanyAzureSettings: ProviderSettings = {
settings: {
host: 'https://login.microsoftazure.de/',
clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255',
signInResourceId: 'https://management.core.cloudapi.de/',
graphResource: {
id: 'https://graph.cloudapi.de/',
endpoint: 'https://graph.cloudapi.de'
@@ -197,7 +193,6 @@ const chinaAzureSettings: ProviderSettings = {
settings: {
host: 'https://login.chinacloudapi.cn/',
clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255',
signInResourceId: 'https://management.core.chinacloudapi.cn/',
graphResource: {
id: 'https://graph.chinacloudapi.cn/',
endpoint: 'https://graph.chinacloudapi.cn'

View File

@@ -35,7 +35,6 @@ export function registerAzureResourceCommands(appContext: AppContext, tree: Azur
const accountNode = node as AzureResourceAccountTreeNode;
const azureAccount = accountNode.account as AzureAccount;
const tokens = await azdata.accounts.getSecurityToken(azureAccount, azdata.AzureResource.MicrosoftResourceManagement);
const terminalService = appContext.getService<IAzureTerminalService>(AzureResourceServiceNames.terminalService);
@@ -64,7 +63,7 @@ export function registerAzureResourceCommands(appContext: AppContext, tree: Azur
tenant = azureAccount.properties.tenants[listOfTenants.indexOf(pickedTenant)];
}
await terminalService.getOrCreateCloudConsole(azureAccount, tenant, tokens);
await terminalService.getOrCreateCloudConsole(azureAccount, tenant);
} catch (ex) {
console.error(ex);
vscode.window.showErrorMessage(ex);
@@ -91,13 +90,14 @@ export function registerAzureResourceCommands(appContext: AppContext, tree: Azur
const subscriptions = (await accountNode.getCachedSubscriptions()) || <azureResource.AzureResourceSubscription[]>[];
if (subscriptions.length === 0) {
try {
const tokens = await azdata.accounts.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
for (const tenant of account.properties.tenants) {
const token = tokens[tenant.id].token;
const tokenType = tokens[tenant.id].tokenType;
const response = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement);
subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType)));
const token = response.token;
const tokenType = response.tokenType;
subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType), tenant.id));
}
} catch (error) {
account.isStale = true;

View File

@@ -8,10 +8,10 @@ import * as msRest from '@azure/ms-rest-js';
import { Account } from 'azdata';
import { azureResource } from './azure-resource';
import { AzureAccount, AzureAccountSecurityToken, Tenant } from '../account-provider/interfaces';
import { AzureAccount, Tenant } from '../account-provider/interfaces';
export interface IAzureResourceSubscriptionService {
getSubscriptions(account: Account, credential: msRest.ServiceClientCredentials): Promise<azureResource.AzureResourceSubscription[]>;
getSubscriptions(account: Account, credential: msRest.ServiceClientCredentials, tenantId: string): Promise<azureResource.AzureResourceSubscription[]>;
}
export interface IAzureResourceSubscriptionFilterService {
@@ -20,7 +20,7 @@ export interface IAzureResourceSubscriptionFilterService {
}
export interface IAzureTerminalService {
getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant, tokens: { [key: string]: AzureAccountSecurityToken }): Promise<void>;
getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant): Promise<void>;
}
export interface IAzureResourceCacheService {
@@ -31,9 +31,6 @@ export interface IAzureResourceCacheService {
update<T>(key: string, value: T): void;
}
export interface IAzureResourceTenantService {
getTenantId(subscription: azureResource.AzureResourceSubscription, account: Account, credential: msRest.ServiceClientCredentials): Promise<string>;
}
export interface IAzureResourceNodeWithProviderId {
resourceProviderId: string;

View File

@@ -41,8 +41,8 @@ export abstract class ResourceTreeDataProviderBase<T extends azureResource.Azure
}
private async getResources(element: azureResource.IAzureResourceNode): Promise<T[]> {
const tokens = await azdata.accounts.getSecurityToken(element.account, azdata.AzureResource.ResourceManagement);
const credential = new msRest.TokenCredentials(tokens[element.tenantId].token, tokens[element.tenantId].tokenType);
const response = await azdata.accounts.getAccountSecurityToken(element.account, element.tenantId, azdata.AzureResource.ResourceManagement);
const credential = new msRest.TokenCredentials(response.token, response.tokenType);
const resources: T[] = await this._resourceService.getResources(element.subscription, credential, element.account) || <T[]>[];
return resources;

View File

@@ -10,14 +10,15 @@ import { azureResource } from '../azure-resource';
import { IAzureResourceSubscriptionService } from '../interfaces';
export class AzureResourceSubscriptionService implements IAzureResourceSubscriptionService {
public async getSubscriptions(account: Account, credential: any): Promise<azureResource.AzureResourceSubscription[]> {
public async getSubscriptions(account: Account, credential: any, tenantId: string): Promise<azureResource.AzureResourceSubscription[]> {
const subscriptions: azureResource.AzureResourceSubscription[] = [];
const subClient = new SubscriptionClient(credential, { baseUri: account.properties.providerSettings.settings.armResource.endpoint });
const subs = await subClient.subscriptions.list();
subs.forEach((sub) => subscriptions.push({
id: sub.subscriptionId,
name: sub.displayName
name: sub.displayName,
tenant: tenantId
}));
return subscriptions;

View File

@@ -1,18 +0,0 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { SubscriptionClient } from '@azure/arm-subscriptions';
import { azureResource } from '../azure-resource';
import { IAzureResourceTenantService } from '../interfaces';
import { Account } from 'azdata';
export class AzureResourceTenantService implements IAzureResourceTenantService {
public async getTenantId(subscription: azureResource.AzureResourceSubscription, account: Account, credentials: any): Promise<string> {
const subClient = new SubscriptionClient(credentials, { baseUri: account.properties.providerSettings.settings.armResource.endpoint });
const result = await subClient.subscriptions.get(subscription.id);
return result.subscriptionId;
}
}

View File

@@ -2,14 +2,14 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as nls from 'vscode-nls';
import axios, { AxiosRequestConfig, AxiosResponse } from 'axios';
import * as WS from 'ws';
import { IAzureTerminalService } from '../interfaces';
import { AzureAccount, AzureAccountSecurityToken, Tenant } from '../../account-provider/interfaces';
import { AzureAccount, Tenant } from '../../account-provider/interfaces';
const localize = nls.loadMessageBundle();
@@ -48,13 +48,13 @@ export class AzureTerminalService implements IAzureTerminalService {
}
public async getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant, tokens: { [key: string]: AzureAccountSecurityToken }): Promise<void> {
const token = tokens[tenant.id].token;
public async getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant): Promise<void> {
const token = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.MicrosoftResourceManagement);
const settings: AxiosRequestConfig = {
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': `Bearer ${token}`
'Authorization': `Bearer ${token.token}`
},
validateStatus: () => true
};
@@ -93,7 +93,7 @@ export class AzureTerminalService implements IAzureTerminalService {
}
const consoleUri = provisionResult.data.properties.uri;
return this.createTerminal(consoleUri, token, account.displayInfo.displayName, preferredShell);
return this.createTerminal(consoleUri, token.token, account.displayInfo.displayName, preferredShell);
}

View File

@@ -20,11 +20,12 @@ import { AzureResourceSubscriptionTreeNode } from './subscriptionTreeNode';
import { AzureResourceMessageTreeNode } from '../messageTreeNode';
import { AzureResourceErrorMessageUtil } from '../utils';
import { IAzureResourceTreeChangeHandler } from './treeChangeHandler';
import { IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService, IAzureResourceTenantService } from '../../azureResource/interfaces';
import { IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService } from '../../azureResource/interfaces';
import { AzureAccount } from '../../account-provider/interfaces';
export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNodeBase {
public constructor(
public readonly account: azdata.Account,
public readonly account: AzureAccount,
appContext: AppContext,
treeChangeHandler: IAzureResourceTreeChangeHandler
) {
@@ -32,7 +33,6 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode
this._subscriptionService = this.appContext.getService<IAzureResourceSubscriptionService>(AzureResourceServiceNames.subscriptionService);
this._subscriptionFilterService = this.appContext.getService<IAzureResourceSubscriptionFilterService>(AzureResourceServiceNames.subscriptionFilterService);
this._tenantService = this.appContext.getService<IAzureResourceTenantService>(AzureResourceServiceNames.tenantService);
this._id = `account_${this.account.key.accountId}`;
this.setCacheKey(`${this._id}.subscriptions`);
@@ -42,15 +42,13 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode
public async getChildren(): Promise<TreeNode[]> {
try {
let subscriptions: azureResource.AzureResourceSubscription[] = [];
const tokens = await azdata.accounts.getSecurityToken(this.account, azdata.AzureResource.ResourceManagement);
if (this._isClearingCache) {
try {
for (const tenant of this.account.properties.tenants) {
const token = tokens[tenant.id].token;
const tokenType = tokens[tenant.id].tokenType;
const token = await azdata.accounts.getAccountSecurityToken(this.account, tenant.id, azdata.AzureResource.ResourceManagement);
subscriptions.push(...(await this._subscriptionService.getSubscriptions(this.account, new TokenCredentials(token, tokenType)) || <azureResource.AzureResourceSubscription[]>[]));
subscriptions.push(...(await this._subscriptionService.getSubscriptions(this.account, new TokenCredentials(token.token, token.tokenType), tenant.id) || <azureResource.AzureResourceSubscription[]>[]));
}
} catch (error) {
throw new AzureResourceCredentialError(localize('azure.resource.tree.accountTreeNode.credentialError', "Failed to get credential for account {0}. Please refresh the account.", this.account.key.accountId), error);
@@ -80,8 +78,8 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode
return [AzureResourceMessageTreeNode.create(AzureResourceAccountTreeNode.noSubscriptionsLabel, this)];
} else {
// Filter out everything that we can't authenticate to.
subscriptions = subscriptions.filter(s => {
const token = tokens[s.id];
subscriptions = subscriptions.filter(async s => {
const token = await azdata.accounts.getAccountSecurityToken(this.account, s.tenant, azdata.AzureResource.ResourceManagement);
if (!token) {
console.info(`Account does not have permissions to view subscription ${JSON.stringify(s)}.`);
return false;
@@ -90,10 +88,7 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode
});
let subTreeNodes = await Promise.all(subscriptions.map(async (subscription) => {
const token = tokens[subscription.id];
const tenantId = await this._tenantService.getTenantId(subscription, this.account, new TokenCredentials(token.token, token.tokenType));
return new AzureResourceSubscriptionTreeNode(this.account, subscription, tenantId, this.appContext, this.treeChangeHandler, this);
return new AzureResourceSubscriptionTreeNode(this.account, subscription, subscription.tenant, this.appContext, this.treeChangeHandler, this);
}));
return subTreeNodes.sort((a, b) => a.subscription.name.localeCompare(b.subscription.name));
}
@@ -166,7 +161,6 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode
private _subscriptionService: IAzureResourceSubscriptionService = undefined;
private _subscriptionFilterService: IAzureResourceSubscriptionFilterService = undefined;
private _tenantService: IAzureResourceTenantService = undefined;
private _id: string = undefined;
private _label: string = undefined;

View File

@@ -151,13 +151,13 @@ export async function getSubscriptions(appContext: AppContext, account?: azdata.
}
const subscriptionService = appContext.getService<IAzureResourceSubscriptionService>(AzureResourceServiceNames.subscriptionService);
const tokens = await azdata.accounts.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
await Promise.all(account.properties.tenants.map(async (tenant: { id: string | number; }) => {
await Promise.all(account.properties.tenants.map(async (tenant: { id: string; }) => {
try {
const token = tokens[tenant.id].token;
const tokenType = tokens[tenant.id].tokenType;
const response = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement);
const token = response.token;
const tokenType = response.tokenType;
result.subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType)));
result.subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType), tenant.id));
} catch (err) {
const error = new Error(localize('azure.accounts.getSubscriptions.queryError', "Error fetching subscriptions for account {0} tenant {1} : {2}",
account.displayInfo.displayName,

View File

@@ -17,12 +17,11 @@ import { AzureResourceDatabaseServerService } from './azureResource/providers/da
import { AzureResourceDatabaseProvider } from './azureResource/providers/database/databaseProvider';
import { AzureResourceDatabaseService } from './azureResource/providers/database/databaseService';
import { AzureResourceService } from './azureResource/resourceService';
import { IAzureResourceCacheService, IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService, IAzureResourceTenantService, IAzureTerminalService } from './azureResource/interfaces';
import { IAzureResourceCacheService, IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService, IAzureTerminalService } from './azureResource/interfaces';
import { AzureResourceServiceNames } from './azureResource/constants';
import { AzureResourceSubscriptionService } from './azureResource/services/subscriptionService';
import { AzureResourceSubscriptionFilterService } from './azureResource/services/subscriptionFilterService';
import { AzureResourceCacheService } from './azureResource/services/cacheService';
import { AzureResourceTenantService } from './azureResource/services/tenantService';
import { registerAzureResourceCommands } from './azureResource/commands';
import { AzureResourceTreeProvider } from './azureResource/tree/treeProvider';
import { SqlInstanceResourceService } from './azureResource/providers/sqlinstance/sqlInstanceService';
@@ -153,7 +152,6 @@ function registerAzureServices(appContext: AppContext): void {
appContext.registerService<IAzureResourceCacheService>(AzureResourceServiceNames.cacheService, new AzureResourceCacheService(extensionContext));
appContext.registerService<IAzureResourceSubscriptionService>(AzureResourceServiceNames.subscriptionService, new AzureResourceSubscriptionService());
appContext.registerService<IAzureResourceSubscriptionFilterService>(AzureResourceServiceNames.subscriptionFilterService, new AzureResourceSubscriptionFilterService(new AzureResourceCacheService(extensionContext)));
appContext.registerService<IAzureResourceTenantService>(AzureResourceServiceNames.tenantService, new AzureResourceTenantService());
appContext.registerService<IAzureTerminalService>(AzureResourceServiceNames.terminalService, new AzureTerminalService(extensionContext));
}

View File

@@ -4,114 +4,240 @@
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import * as os from 'os';
import * as TypeMoq from 'typemoq';
// import * as azdata from 'azdata';
// import * as vscode from 'vscode';
// import * as sinon from 'sinon';
import 'mocha';
import * as vscode from 'vscode';
import { PromptFailedResult, AccountKey } from 'azdata';
import { AzureAuth, AccessToken, RefreshToken, TokenClaims, TokenRefreshResponse } from '../../../account-provider/auths/azureAuth';
import { AzureAccount, AzureAuthType, Deferred, Tenant } from '../../../account-provider/interfaces';
import { AzureAuthCodeGrant } from '../../../account-provider/auths/azureAuthCodeGrant';
// import { AzureDeviceCode } from '../../../account-provider/auths/azureDeviceCode';
import { Token, TokenClaims, AccessToken, RefreshToken, OAuthTokenResponse, TokenPostData } from '../../../account-provider/auths/azureAuth';
import { Tenant, AzureAccount } from '../../../account-provider/interfaces';
import providerSettings from '../../../account-provider/providerSettings';
import { SimpleTokenCache } from '../../../account-provider/simpleTokenCache';
import { CredentialsTestProvider } from '../../stubs/credentialsTestProvider';
import { AzureResource } from 'azdata';
import { AxiosResponse } from 'axios';
class BasicAzureAuth extends AzureAuth {
public async login(): Promise<AzureAccount | PromptFailedResult> {
throw new Error('Method not implemented.');
}
public async autoOAuthCancelled(): Promise<void> {
throw new Error('Method not implemented.');
}
let azureAuthCodeGrant: TypeMoq.IMock<AzureAuthCodeGrant>;
// let azureDeviceCode: TypeMoq.IMock<AzureDeviceCode>;
public async promptForConsent(): Promise<{ tokenRefreshResponse: TokenRefreshResponse, authCompleteDeferred: Deferred<void> } | undefined> {
throw new Error('Method not implemented.');
}
}
const mockToken: Token = {
key: 'someUniqueId',
token: 'test_token',
tokenType: 'Bearer'
};
let mockAccessToken: AccessToken;
let mockRefreshToken: RefreshToken;
let baseAuth: AzureAuth;
const mockClaims = {
name: 'Name',
email: 'example@example.com',
sub: 'someUniqueId'
} as TokenClaims;
const accountKey: AccountKey = {
accountId: 'SomeAccountKey',
providerId: 'providerId',
const mockTenant: Tenant = {
displayName: 'Tenant Name',
id: 'tenantID',
tenantCategory: 'Home',
userId: 'test_user'
};
const accessToken: AccessToken = {
key: accountKey.accountId,
token: '123'
};
let mockAccount: AzureAccount;
const refreshToken: RefreshToken = {
key: accountKey.accountId,
token: '321'
};
const provider = providerSettings[0].metadata;
const resourceId = 'resource';
const tenantId = 'tenant';
describe('Azure Authentication', function () {
beforeEach(function () {
azureAuthCodeGrant = TypeMoq.Mock.ofType<AzureAuthCodeGrant>(AzureAuthCodeGrant, TypeMoq.MockBehavior.Loose, true, provider);
// azureDeviceCode = TypeMoq.Mock.ofType<AzureDeviceCode>();
const tenant: Tenant = {
id: tenantId,
displayName: 'common'
};
azureAuthCodeGrant.callBase = true;
// authDeviceCode.callBase = true;
// These tests don't work on Linux systems because gnome-keyring doesn't like running on headless machines.
describe('AccountProvider.AzureAuth', function (): void {
beforeEach(async function (): Promise<void> {
const tokenCache = new SimpleTokenCache('testTokenService', os.tmpdir(), true, new CredentialsTestProvider());
await tokenCache.init();
baseAuth = new BasicAzureAuth(providerSettings[0].metadata, tokenCache, undefined, undefined, AzureAuthType.AuthCodeGrant, 'Auth Code Grant');
mockAccount = {
isStale: false,
properties: {
tenants: [mockTenant]
}
} as AzureAccount;
mockAccessToken = {
...mockToken
};
mockRefreshToken = {
...mockToken
};
});
it('Basic token set and get', async function (): Promise<void> {
await baseAuth.setCachedToken(accountKey, accessToken, refreshToken);
const result = await baseAuth.getCachedToken(accountKey);
it('accountHydration should yield a valid account', async function () {
should(JSON.stringify(result.accessToken)).be.equal(JSON.stringify(accessToken));
should(JSON.stringify(result.refreshToken)).be.equal(JSON.stringify(refreshToken));
azureAuthCodeGrant.setup(x => x.getTenants(mockToken)).returns((): Promise<Tenant[]> => {
return Promise.resolve([
mockTenant
]);
});
const response = await azureAuthCodeGrant.object.hydrateAccount(mockToken, mockClaims);
should(response.displayInfo.displayName).be.equal(`${mockClaims.name} - ${mockClaims.email}`, 'Account name should match');
should(response.displayInfo.userId).be.equal(mockClaims.sub, 'Account ID should match');
should(response.properties.tenants).be.deepEqual([mockTenant], 'Tenants should match');
});
it('Token set and get with tenant and resource id', async function (): Promise<void> {
await baseAuth.setCachedToken(accountKey, accessToken, refreshToken, resourceId, tenantId);
let result = await baseAuth.getCachedToken(accountKey, resourceId, tenantId);
describe('getAccountSecurityToken', function () {
it('should be undefined on stale account', async function () {
mockAccount.isStale = true;
const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, TypeMoq.It.isAny(), TypeMoq.It.isAny());
should(securityToken).be.undefined();
});
it('dont find correct resources', async function () {
const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, TypeMoq.It.isAny(), -1);
should(securityToken).be.undefined();
});
it('incorrect tenant', async function () {
await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, 'invalid_tenant', AzureResource.MicrosoftResourceManagement).should.be.rejected();
});
should(JSON.stringify(result.accessToken)).be.equal(JSON.stringify(accessToken));
should(JSON.stringify(result.refreshToken)).be.equal(JSON.stringify(refreshToken));
it('saved token exists and can be reused', async function () {
delete (mockAccessToken as any).tokenType;
azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => {
return Promise.resolve({
accessToken: mockAccessToken,
refreshToken: mockRefreshToken,
expiresOn: `${(new Date().getTime() / 1000) + (10 * 60)}`
})
});
const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement);
should(securityToken.tokenType).be.equal('Bearer', 'tokenType should be bearer on a successful getSecurityToken from cache')
});
it('saved token had invalid expiration', async function () {
delete (mockAccessToken as any).tokenType;
(mockAccessToken as any).invalidData = 'this should not exist on response';
azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => {
return Promise.resolve({
accessToken: mockAccessToken,
refreshToken: mockRefreshToken,
expiresOn: undefined
});
});
azureAuthCodeGrant.setup(x => x.refreshToken(mockTenant, provider.settings.microsoftResource, mockRefreshToken)).returns((): Promise<OAuthTokenResponse> => {
const mockToken: AccessToken = JSON.parse(JSON.stringify(mockAccessToken));
delete (mockToken as any).invalidData;
return Promise.resolve({
accessToken: mockToken
} as OAuthTokenResponse);
});
const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement);
should((securityToken as any).invalidData).be.undefined(); // Ensure its a new one
should(securityToken.tokenType).be.equal('Bearer', 'tokenType should be bearer on a successful getSecurityToken from cache')
azureAuthCodeGrant.verify(x => x.refreshToken(mockTenant, provider.settings.microsoftResource, mockRefreshToken), TypeMoq.Times.once());
});
describe('no saved token', function () {
it('no base token', async function () {
azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => {
return Promise.resolve(undefined);
});
azureAuthCodeGrant.setup(x => x.getSavedToken(azureAuthCodeGrant.object.commonTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => {
return Promise.resolve(undefined);
});
await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement).should.be.rejected();
});
it('base token exists', async function () {
azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => {
return Promise.resolve(undefined);
});
azureAuthCodeGrant.setup(x => x.getSavedToken(azureAuthCodeGrant.object.commonTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => {
return Promise.resolve({
accessToken: mockAccessToken,
refreshToken: mockRefreshToken,
expiresOn: ''
});
});
delete (mockAccessToken as any).tokenType;
azureAuthCodeGrant.setup(x => x.refreshToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return Promise.resolve({
accessToken: mockAccessToken
} as OAuthTokenResponse);
});
const securityToken: Token = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement);
should(securityToken.tokenType).be.equal('Bearer', 'tokenType should be bearer on a successful getSecurityToken from cache')
});
});
await baseAuth.clearCredentials(accountKey);
result = await baseAuth.getCachedToken(accountKey, resourceId, tenantId);
should(result).be.undefined();
});
it('Token set with resource ID and get without tenant and resource id', async function (): Promise<void> {
await baseAuth.setCachedToken(accountKey, accessToken, refreshToken, resourceId, tenantId);
const result = await baseAuth.getCachedToken(accountKey);
describe('getToken', function () {
should(JSON.stringify(result)).be.undefined();
should(JSON.stringify(result)).be.undefined();
it('calls handle interaction required', async function () {
azureAuthCodeGrant.setup(x => x.makePostRequest(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return Promise.resolve({
data: {
error: 'interaction_required'
}
} as AxiosResponse<any>);
});
azureAuthCodeGrant.setup(x => x.handleInteractionRequired(mockTenant, provider.settings.microsoftResource)).returns(() => {
return Promise.resolve({
accessToken: mockAccessToken
} as OAuthTokenResponse);
});
const result = await azureAuthCodeGrant.object.getToken(mockTenant, provider.settings.microsoftResource, {} as TokenPostData);
azureAuthCodeGrant.verify(x => x.handleInteractionRequired(mockTenant, provider.settings.microsoftResource), TypeMoq.Times.once());
should(result.accessToken).be.deepEqual(mockAccessToken);
});
it('unknown error should throw error', async function () {
azureAuthCodeGrant.setup(x => x.makePostRequest(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return Promise.resolve({
data: {
error: 'unknown error'
}
} as AxiosResponse<any>);
});
await azureAuthCodeGrant.object.getToken(mockTenant, provider.settings.microsoftResource, {} as TokenPostData).should.be.rejected();
});
it('calls getTokenHelper', async function () {
azureAuthCodeGrant.setup(x => x.makePostRequest(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return Promise.resolve({
data: {
access_token: mockAccessToken.token,
refresh_token: mockRefreshToken.token,
expires_on: `0`
}
} as AxiosResponse<any>);
});
azureAuthCodeGrant.setup(x => x.getTokenHelper(mockTenant, provider.settings.microsoftResource, TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return Promise.resolve({
accessToken: mockAccessToken
} as OAuthTokenResponse);
});
const result = await azureAuthCodeGrant.object.getToken(mockTenant, provider.settings.microsoftResource, {} as TokenPostData);
azureAuthCodeGrant.verify(x => x.getTokenHelper(mockTenant, provider.settings.microsoftResource, TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
should(result.accessToken).be.deepEqual(mockAccessToken);
});
});
it('Create an account object', async function (): Promise<void> {
const tokenClaims = {
idp: 'live.com',
name: 'TestAccount',
} as TokenClaims;
const account = baseAuth.createAccount(tokenClaims, 'someKey', undefined);
should(account.properties.azureAuthType).be.equal(AzureAuthType.AuthCodeGrant);
should(account.key.accountId).be.equal('someKey');
should(account.properties.isMsAccount).be.equal(true);
});
it('Should handle ignored tenants', async function (): Promise<void> {
// Don't sit on the await openConsentDialog if test is failing
this.timeout(3000);
const configuration = vscode.workspace.getConfiguration('azure.tenant.config');
const values = [tenantId];
await configuration.update('filter', values, vscode.ConfigurationTarget.Global);
const x = await baseAuth.openConsentDialog(tenant, resourceId);
should(x).be.false();
});
});

View File

@@ -41,12 +41,15 @@ const mockAccount: AzureAccount = {
isStale: false
};
const mockTenantId: string = 'mock_tenant';
const mockSubscription: azureResource.AzureResourceSubscription = {
id: 'mock_subscription',
name: 'mock subscription'
name: 'mock subscription',
tenant: mockTenantId
};
const mockTenantId: string = 'mock_tenant';
const mockResourceRootNode: azureResource.IAzureResourceNode = {
account: mockAccount,
@@ -61,8 +64,7 @@ const mockResourceRootNode: azureResource.IAzureResourceNode = {
}
};
const mockTokens: { [key: string]: any } = {};
mockTokens[mockTenantId] = {
const mockToken = {
token: 'mock_token',
tokenType: 'Bearer'
};
@@ -106,7 +108,7 @@ describe('AzureResourceDatabaseTreeDataProvider.getChildren', function (): void
mockDatabaseService = TypeMoq.Mock.ofType<IAzureResourceService<azureResource.AzureResourceDatabase>>();
mockExtensionContext = TypeMoq.Mock.ofType<vscode.ExtensionContext>();
sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens));
sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken));
mockDatabaseService.setup((o) => o.getResources(mockSubscription, TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockDatabases));
mockExtensionContext.setup((o) => o.asAbsolutePath(TypeMoq.It.isAnyString())).returns(() => TypeMoq.It.isAnyString());
});

View File

@@ -41,12 +41,14 @@ const mockAccount: AzureAccount = {
isStale: false
};
const mockTenantId: string = 'mock_tenant';
const mockSubscription: azureResource.AzureResourceSubscription = {
id: 'mock_subscription',
name: 'mock subscription'
name: 'mock subscription',
tenant: mockTenantId
};
const mockTenantId: string = 'mock_tenant';
const mockResourceRootNode: azureResource.IAzureResourceNode = {
account: mockAccount,
@@ -61,8 +63,7 @@ const mockResourceRootNode: azureResource.IAzureResourceNode = {
}
};
const mockTokens: { [key: string]: any } = {};
mockTokens[mockTenantId] = {
const mockToken = {
token: 'mock_token',
tokenType: 'Bearer'
};
@@ -106,7 +107,7 @@ describe('AzureResourceDatabaseServerTreeDataProvider.getChildren', function ():
mockDatabaseServerService = TypeMoq.Mock.ofType<IAzureResourceService<azureResource.AzureResourceDatabaseServer>>();
mockExtensionContext = TypeMoq.Mock.ofType<vscode.ExtensionContext>();
sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens));
sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken));
mockDatabaseServerService.setup((o) => o.getResources(mockSubscription, TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockDatabaseServers));
mockExtensionContext.setup((o) => o.asAbsolutePath(TypeMoq.It.isAnyString())).returns(() => TypeMoq.It.isAnyString());
});

View File

@@ -33,13 +33,14 @@ const mockAccount: AzureAccount = {
isStale: false
};
const mockTenantId: string = 'mock_tenant';
const mockSubscription: azureResource.AzureResourceSubscription = {
id: 'mock_subscription',
name: 'mock subscription'
name: 'mock subscription',
tenant: mockTenantId
};
const mockTenantId: string = 'mock_tenant';
let mockResourceTreeDataProvider1: TypeMoq.IMock<azureResource.IAzureResourceTreeDataProvider>;
let mockResourceProvider1: TypeMoq.IMock<azureResource.IAzureResourceProvider>;

View File

@@ -36,12 +36,14 @@ const mockAccount: AzureAccount = {
isStale: false
};
const mockTenantId: string = 'mock_tenant';
const mockSubscription: azureResource.AzureResourceSubscription = {
id: 'mock_subscription',
name: 'mock subscription'
name: 'mock subscription',
tenant: mockTenantId
};
const mockTenantId: string = 'mock_tenant';
const mockResourceProviderId: string = 'mock_resource_provider';

View File

@@ -17,7 +17,6 @@ import {
IAzureResourceCacheService,
IAzureResourceSubscriptionService,
IAzureResourceSubscriptionFilterService,
IAzureResourceTenantService
} from '../../../azureResource/interfaces';
import { IAzureResourceTreeChangeHandler } from '../../../azureResource/tree/treeChangeHandler';
import { AzureResourceAccountTreeNode } from '../../../azureResource/tree/accountTreeNode';
@@ -31,7 +30,6 @@ let mockExtensionContext: TypeMoq.IMock<vscode.ExtensionContext>;
let mockCacheService: TypeMoq.IMock<IAzureResourceCacheService>;
let mockSubscriptionService: TypeMoq.IMock<IAzureResourceSubscriptionService>;
let mockSubscriptionFilterService: TypeMoq.IMock<IAzureResourceSubscriptionFilterService>;
let mockTenantService: TypeMoq.IMock<IAzureResourceTenantService>;
let mockAppContext: AppContext;
let getSecurityTokenStub: sinon.SinonStub;
let mockTreeChangeHandler: TypeMoq.IMock<IAzureResourceTreeChangeHandler>;
@@ -63,28 +61,27 @@ const mockAccount: azdata.Account = {
const mockSubscription1: azureResource.AzureResourceSubscription = {
id: 'mock_subscription_1',
name: 'mock subscription 1'
name: 'mock subscription 1',
tenant: mockTenantId
};
const mockSubscription2: azureResource.AzureResourceSubscription = {
id: 'mock_subscription_2',
name: 'mock subscription 2'
name: 'mock subscription 2',
tenant: mockTenantId
};
const mockSubscriptions = [mockSubscription1, mockSubscription2];
const mockFilteredSubscriptions = [mockSubscription1];
const mockTokens: { [key: string]: any } = {};
const mockToken = {
token: 'mock_token',
tokenType: 'Bearer'
};
[mockSubscription1.id, mockSubscription2.id, mockTenantId].forEach(s => {
mockTokens[s] = {
token: 'mock_token',
tokenType: 'Bearer'
};
});
const mockCredential = new TokenCredentials(mockTokens[mockTenantId].token, mockTokens[mockTenantId].tokenType);
const mockCredential = new TokenCredentials(mockToken.token, mockToken.tokenType);
let mockSubscriptionCache: azureResource.AzureResourceSubscription[] = [];
@@ -94,7 +91,6 @@ describe('AzureResourceAccountTreeNode.info', function (): void {
mockCacheService = TypeMoq.Mock.ofType<IAzureResourceCacheService>();
mockSubscriptionService = TypeMoq.Mock.ofType<IAzureResourceSubscriptionService>();
mockSubscriptionFilterService = TypeMoq.Mock.ofType<IAzureResourceSubscriptionFilterService>();
mockTenantService = TypeMoq.Mock.ofType<IAzureResourceTenantService>();
mockTreeChangeHandler = TypeMoq.Mock.ofType<IAzureResourceTreeChangeHandler>();
@@ -104,13 +100,11 @@ describe('AzureResourceAccountTreeNode.info', function (): void {
mockAppContext.registerService<IAzureResourceCacheService>(AzureResourceServiceNames.cacheService, mockCacheService.object);
mockAppContext.registerService<IAzureResourceSubscriptionService>(AzureResourceServiceNames.subscriptionService, mockSubscriptionService.object);
mockAppContext.registerService<IAzureResourceSubscriptionFilterService>(AzureResourceServiceNames.subscriptionFilterService, mockSubscriptionFilterService.object);
mockAppContext.registerService<IAzureResourceTenantService>(AzureResourceServiceNames.tenantService, mockTenantService.object);
getSecurityTokenStub = sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens));
getSecurityTokenStub = sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken));
mockCacheService.setup((o) => o.generateKey(TypeMoq.It.isAnyString())).returns(() => generateGuid());
mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache);
mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache = mockSubscriptions);
mockTenantService.setup((o) => o.getTenantId(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockTenantId));
});
afterEach(function (): void {
@@ -138,7 +132,7 @@ describe('AzureResourceAccountTreeNode.info', function (): void {
});
it('Should be correct when there are subscriptions listed.', async function (): Promise<void> {
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(undefined));
const accountTreeNodeLabel = `${mockAccount.displayInfo.displayName} (${mockSubscriptions.length} / ${mockSubscriptions.length} subscriptions)`;
@@ -158,7 +152,7 @@ describe('AzureResourceAccountTreeNode.info', function (): void {
});
it('Should be correct when there are subscriptions filtered.', async function (): Promise<void> {
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(mockFilteredSubscriptions));
const accountTreeNodeLabel = `${mockAccount.displayInfo.displayName} (${mockFilteredSubscriptions.length} / ${mockSubscriptions.length} subscriptions)`;
@@ -184,7 +178,6 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
mockCacheService = TypeMoq.Mock.ofType<IAzureResourceCacheService>();
mockSubscriptionService = TypeMoq.Mock.ofType<IAzureResourceSubscriptionService>();
mockSubscriptionFilterService = TypeMoq.Mock.ofType<IAzureResourceSubscriptionFilterService>();
mockTenantService = TypeMoq.Mock.ofType<IAzureResourceTenantService>();
mockTreeChangeHandler = TypeMoq.Mock.ofType<IAzureResourceTreeChangeHandler>();
@@ -194,13 +187,11 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
mockAppContext.registerService<IAzureResourceCacheService>(AzureResourceServiceNames.cacheService, mockCacheService.object);
mockAppContext.registerService<IAzureResourceSubscriptionService>(AzureResourceServiceNames.subscriptionService, mockSubscriptionService.object);
mockAppContext.registerService<IAzureResourceSubscriptionFilterService>(AzureResourceServiceNames.subscriptionFilterService, mockSubscriptionFilterService.object);
mockAppContext.registerService<IAzureResourceTenantService>(AzureResourceServiceNames.tenantService, mockTenantService.object);
sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens));
sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken));
mockCacheService.setup((o) => o.generateKey(TypeMoq.It.isAnyString())).returns(() => generateGuid());
mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache);
mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache = mockSubscriptions);
mockTenantService.setup((o) => o.getTenantId(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockTenantId));
});
afterEach(function (): void {
@@ -208,14 +199,14 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
});
it('Should load subscriptions from scratch and update cache when it is clearing cache.', async function (): Promise<void> {
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve([]));
const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object);
const children = await accountTreeNode.getChildren();
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential), TypeMoq.Times.once());
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(0));
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockSubscriptionFilterService.verify((o) => o.getSelectedSubscriptions(mockAccount), TypeMoq.Times.once());
@@ -241,7 +232,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
});
it('Should load subscriptions from cache when it is not clearing cache.', async function (): Promise<void> {
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(undefined));
const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object);
@@ -250,7 +241,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
const children = await accountTreeNode.getChildren();
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential), TypeMoq.Times.once());
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once());
@@ -262,7 +253,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
});
it('Should handle when there is no subscriptions.', async function (): Promise<void> {
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(undefined));
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(undefined));
const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object);
@@ -278,7 +269,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
});
it('Should honor subscription filtering.', async function (): Promise<void> {
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(mockFilteredSubscriptions));
const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object);
@@ -296,7 +287,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
});
it('Should handle errors.', async function (): Promise<void> {
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions));
mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions));
const mockError = 'Test error';
mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => { throw new Error(mockError); });
@@ -305,8 +296,8 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void {
const children = await accountTreeNode.getChildren();
should(getSecurityTokenStub.calledOnce).be.true('getSecurityToken should have been called exactly once');
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential), TypeMoq.Times.once());
should(getSecurityTokenStub.calledTwice).be.true('getSecurityToken should have been called exactly twice - once per subscription');
mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId), TypeMoq.Times.once());
mockSubscriptionFilterService.verify((o) => o.getSelectedSubscriptions(mockAccount), TypeMoq.Times.once());
mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never());
mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once());
@@ -325,7 +316,6 @@ describe('AzureResourceAccountTreeNode.clearCache', function (): void {
mockCacheService = TypeMoq.Mock.ofType<IAzureResourceCacheService>();
mockSubscriptionService = TypeMoq.Mock.ofType<IAzureResourceSubscriptionService>();
mockSubscriptionFilterService = TypeMoq.Mock.ofType<IAzureResourceSubscriptionFilterService>();
mockTenantService = TypeMoq.Mock.ofType<IAzureResourceTenantService>();
mockTreeChangeHandler = TypeMoq.Mock.ofType<IAzureResourceTreeChangeHandler>();
@@ -335,13 +325,11 @@ describe('AzureResourceAccountTreeNode.clearCache', function (): void {
mockAppContext.registerService<IAzureResourceCacheService>(AzureResourceServiceNames.cacheService, mockCacheService.object);
mockAppContext.registerService<IAzureResourceSubscriptionService>(AzureResourceServiceNames.subscriptionService, mockSubscriptionService.object);
mockAppContext.registerService<IAzureResourceSubscriptionFilterService>(AzureResourceServiceNames.subscriptionFilterService, mockSubscriptionFilterService.object);
mockAppContext.registerService<IAzureResourceTenantService>(AzureResourceServiceNames.tenantService, mockTenantService.object);
sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens));
sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken));
mockCacheService.setup((o) => o.generateKey(TypeMoq.It.isAnyString())).returns(() => generateGuid());
mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache);
mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache = mockSubscriptions);
mockTenantService.setup((o) => o.getTenantId(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockTenantId));
});
afterEach(function (): void {

View File

@@ -43,13 +43,14 @@ const mockAccount: azdata.Account = {
isStale: false
};
const mockTenantId: string = 'mock_tenant';
const mockSubscription: azureResource.AzureResourceSubscription = {
id: 'mock_subscription',
name: 'mock subscription'
name: 'mock subscription',
tenant: mockTenantId
};
const mockTenantId: string = 'mock_tenant';
let mockResourceTreeDataProvider1: TypeMoq.IMock<azureResource.IAzureResourceTreeDataProvider>;
let mockResourceProvider1: TypeMoq.IMock<azureResource.IAzureResourceProvider>;

View File

@@ -7,17 +7,13 @@ export class Logger {
private static _piiLogging: boolean = false;
static log(msg: any, ...vals: any[]) {
if (vals && vals.length > 0) {
return console.log(msg, vals);
}
console.log(msg);
const fullMessage = `${msg} - ${vals.map(v => JSON.stringify(v)).join(' - ')}`;
console.log(fullMessage);
}
static error(msg: any, ...vals: any[]) {
if (vals && vals.length > 0) {
return console.error(msg, vals);
}
console.error(msg);
const fullMessage = `${msg} - ${vals.map(v => JSON.stringify(v)).join(' - ')}`;
console.error(fullMessage);
}
static pii(msg: any, ...vals: any[]) {

View File

@@ -70,22 +70,22 @@ export class AccountFeature implements StaticFeature {
account = accountList[0];
}
const securityToken: { [key: string]: any } = await azdata.accounts.getSecurityToken(account, azdata.AzureResource.AzureKeyVault);
const tenant = account.properties.tenants.find((t: { [key: string]: string }) => request.authority.includes(t.id));
const unauthorizedMessage = localize('mssql.insufficientlyPrivelagedAzureAccount', "The configured Azure account for {0} does not have sufficient permissions for Azure Key Vault to access a column master key for Always Encrypted.", account.key.accountId);
if (!tenant) {
window.showErrorMessage(unauthorizedMessage);
return undefined;
}
let tokenBundle = securityToken[tenant.id];
if (!tokenBundle) {
const securityToken = await azdata.accounts.getAccountSecurityToken(account, tenant, azdata.AzureResource.AzureKeyVault);
if (!securityToken?.token) {
window.showErrorMessage(unauthorizedMessage);
return undefined;
}
let params: contracts.RequestSecurityTokenResponse = {
accountKey: JSON.stringify(account.key),
token: securityToken[tenant.id].token
token: securityToken.token
};
return params;

10
src/sql/azdata.d.ts vendored
View File

@@ -2131,9 +2131,18 @@ declare module 'azdata' {
* @param account Account to generate security token for (defaults to
* AzureResource.ResourceManagement if not given)
* @return Promise to return the security token
* @deprecated use getAccountSecurityToken
*/
export function getSecurityToken(account: Account, resource?: AzureResource): Thenable<{ [key: string]: any }>;
/**
* Generates a security token by asking the account's provider
* @param account
* @param tenant
* @param resource
*/
export function getAccountSecurityToken(account: Account, tenant: string, resource: AzureResource): Thenable<{ token: string, tokenType?: string } | undefined>;
/**
* An [event](#Event) which fires when the accounts have changed.
*/
@@ -2279,6 +2288,7 @@ declare module 'azdata' {
* @param account The account to generate a security token for
* @param resource The resource to get the token for
* @return Promise to return a security token object
* @deprecated use getAccountSecurityToken
*/
getSecurityToken(account: Account, resource: AzureResource): Thenable<{} | undefined>;

View File

@@ -494,4 +494,14 @@ declare module 'azdata' {
name?: string;
}
export interface AccountProvider {
/**
* Generates a security token for the provided account and tenant
* @param account The account to generate a security token for
* @param resource The resource to get the token for
* @return Promise to return a security token object
*/
getAccountSecurityToken(account: Account, tenant: string, resource: AzureResource): Thenable<{ token: string } | undefined>;
}
}

View File

@@ -21,7 +21,11 @@ export interface IAccountManagementService {
getAccountProviderMetadata(): Thenable<azdata.AccountProviderMetadata[]>;
getAccountsForProvider(providerId: string): Thenable<azdata.Account[]>;
getAccounts(): Thenable<azdata.Account[]>;
/**
* @deprecated
*/
getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{ [key: string]: { token: string } }>;
getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }>;
removeAccount(accountKey: azdata.AccountKey): Thenable<boolean>;
removeAccounts(): Thenable<boolean>;
refreshAccount(account: azdata.Account): Thenable<azdata.Account>;

View File

@@ -55,6 +55,10 @@ export class TestAccountManagementService implements IAccountManagementService {
return Promise.resolve([]);
}
getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> {
return Promise.resolve(undefined);
}
removeAccount(accountKey: azdata.AccountKey): Thenable<boolean> {
throw new Error('Method not implemented');
}
@@ -100,6 +104,9 @@ export class AccountProviderStub implements azdata.AccountProvider {
return Promise.resolve({});
}
getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> {
return Promise.resolve(undefined);
}
initialize(storedAccounts: azdata.Account[]): Thenable<azdata.Account[]> {
return Promise.resolve(storedAccounts);
}

View File

@@ -75,9 +75,13 @@ export class MainThreadAccountManagement extends Disposable implements MainThrea
clear(accountKey: azdata.AccountKey): Thenable<void> {
return self._proxy.$clear(handle, accountKey);
},
getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{}> {
return self._proxy.$getSecurityToken(account, resource);
},
getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> {
return self._proxy.$getAccountSecurityToken(account, tenant, resource);
},
initialize(restoredAccounts: azdata.Account[]): Thenable<azdata.Account[]> {
return self._proxy.$initialize(handle, restoredAccounts);
},

View File

@@ -89,10 +89,7 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape {
return Promise.all(promises).then(() => resultAccounts);
}
public $getSecurityToken(account: azdata.Account, resource?: azdata.AzureResource): Thenable<{}> {
if (resource === undefined) {
resource = AzureResource.ResourceManagement;
}
public $getSecurityToken(account: azdata.Account, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{}> {
return this.$getAllAccounts().then(() => {
for (const handle in this._accounts) {
const providerHandle = parseInt(handle);
@@ -105,6 +102,20 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape {
});
}
public $getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{ token: string }> {
return this.$getAllAccounts().then(() => {
for (const handle in this._accounts) {
const providerHandle = parseInt(handle);
if (firstIndex(this._accounts[handle], (acct) => acct.key.accountId === account.key.accountId) !== -1) {
return this._withProvider(providerHandle, (provider: azdata.AccountProvider) => provider.getAccountSecurityToken(account, tenant, resource));
}
}
throw new Error(`Account ${account.key.accountId} not found.`);
});
}
public get onDidChangeAccounts(): Event<azdata.DidChangeAccountsParams> {
return this._onDidChangeAccounts.event;
}

View File

@@ -160,6 +160,9 @@ export function createAdsApiFactory(accessor: ServicesAccessor): IAdsExtensionAp
getSecurityToken(account: azdata.Account, resource?: azdata.AzureResource): Thenable<{}> {
return extHostAccountManagement.$getSecurityToken(account, resource);
},
getAccountSecurityToken(account: azdata.Account, tenant: string, resource?: azdata.AzureResource): Thenable<{ token: string }> {
return extHostAccountManagement.$getAccountSecurityToken(account, tenant, resource);
},
onDidChangeAccounts(listener: (e: azdata.DidChangeAccountsParams) => void, thisArgs?: any, disposables?: extHostTypes.Disposable[]) {
return extHostAccountManagement.onDidChangeAccounts(listener, thisArgs, disposables);
}

View File

@@ -31,6 +31,7 @@ export abstract class ExtHostAccountManagementShape {
$autoOAuthCancelled(handle: number): Thenable<void> { throw ni(); }
$clear(handle: number, accountKey: azdata.AccountKey): Thenable<void> { throw ni(); }
$getSecurityToken(account: azdata.Account, resource?: azdata.AzureResource): Thenable<{}> { throw ni(); }
$getAccountSecurityToken(account: azdata.Account, tenant: string, resource?: azdata.AzureResource): Thenable<{ token: string }> { throw ni(); }
$initialize(handle: number, restoredAccounts: azdata.Account[]): Thenable<azdata.Account[]> { throw ni(); }
$prompt(handle: number): Thenable<azdata.Account | azdata.PromptFailedResult> { throw ni(); }
$refresh(handle: number, account: azdata.Account): Thenable<azdata.Account | azdata.PromptFailedResult> { throw ni(); }

View File

@@ -244,6 +244,19 @@ export class AccountManagementService implements IAccountManagementService {
});
}
/**
* Generates a security token by asking the account's provider
* @param account Account to generate security token for
* @param tenant Tenant to generate security token for
* @param resource The resource to get the security token for
* @return Promise to return the security token
*/
public getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> {
return this.doWithProvider(account.key.providerId, provider => {
return provider.provider.getAccountSecurityToken(account, tenant, resource);
});
}
/**
* Removes an account from the account store and clears sensitive data in the provider
* @param accountKey Key for the account to remove

View File

@@ -813,8 +813,8 @@ export class ConnectionManagementService extends Disposable implements IConnecti
const accounts = await this._accountManagementService.getAccounts();
const azureAccounts = accounts.filter(a => a.key.providerId.startsWith('azure'));
if (azureAccounts && azureAccounts.length > 0) {
let accountName = (connection.authenticationType === Constants.azureMFA || connection.authenticationType === Constants.azureMFAAndUser) ? connection.azureAccount : connection.userName;
let account = find(azureAccounts, account => account.key.accountId === accountName);
let accountId = (connection.authenticationType === Constants.azureMFA || connection.authenticationType === Constants.azureMFAAndUser) ? connection.azureAccount : connection.userName;
let account = find(azureAccounts, account => account.key.accountId === accountId);
if (account) {
this._logService.debug(`Getting security token for Azure account ${account.key.accountId}`);
if (account.isStale) {
@@ -827,26 +827,17 @@ export class ConnectionManagementService extends Disposable implements IConnecti
return false;
}
}
const tokensByTenant = await this._accountManagementService.getSecurityToken(account, azureResource);
this._logService.debug(`Got tokens for tenants [${Object.keys(tokensByTenant).join(',')}]`);
let token: string;
const tenantId = connection.azureTenantId;
if (tenantId && tokensByTenant[tenantId]) {
token = tokensByTenant[tenantId].token;
} else {
this._logService.debug(`No security token found for specific tenant ${tenantId} - falling back to first one`);
const tokens = values(tokensByTenant);
if (tokens.length === 0) {
this._logService.info(`No security tokens found for account`);
return false;
}
token = tokens[0].token;
const token = await this._accountManagementService.getAccountSecurityToken(account, tenantId, azureResource);
this._logService.debug(`Got token for tenant ${token}`);
if (!token) {
this._logService.info(`No security tokens found for account`);
}
connection.options['azureAccountToken'] = token;
connection.options['azureAccountToken'] = token.token;
connection.options['password'] = '';
return true;
} else {
this._logService.info(`Could not find Azure account with name ${accountName}`);
this._logService.info(`Could not find Azure account with name ${accountId}`);
}
} else {
this._logService.info(`Could not find any Azure accounts from accounts : [${accounts.map(a => `${a.key.accountId} (${a.key.providerId})`).join(',')}]`);

View File

@@ -6,7 +6,7 @@
import 'vs/css!./media/sqlConnection';
import { Button } from 'sql/base/browser/ui/button/button';
import { SelectBox } from 'sql/base/browser/ui/selectBox/selectBox';
import { SelectBox, SelectOptionItemSQL } from 'sql/base/browser/ui/selectBox/selectBox';
import { Checkbox } from 'sql/base/browser/ui/checkbox/checkbox';
import { InputBox } from 'sql/base/browser/ui/inputBox/inputBox';
import * as DialogHelper from 'sql/workbench/browser/modal/dialogHelper';
@@ -520,12 +520,18 @@ export class ConnectionWidget extends lifecycle.Disposable {
let oldSelection = this._azureAccountDropdown.value;
const accounts = await this._accountManagementService.getAccounts();
this._azureAccountList = accounts.filter(a => a.key.providerId.startsWith('azure'));
let accountDropdownOptions = this._azureAccountList.map(account => account.displayInfo.displayName);
let accountDropdownOptions: SelectOptionItemSQL[] = this._azureAccountList.map(account => {
return {
text: account.displayInfo.displayName,
value: account.key.accountId
} as SelectOptionItemSQL;
});
if (accountDropdownOptions.length === 0) {
// If there are no accounts add a blank option so that add account isn't automatically selected
accountDropdownOptions.unshift('');
accountDropdownOptions.unshift({ text: '', value: '' });
}
accountDropdownOptions.push(this._addAzureAccountMessage);
accountDropdownOptions.push({ text: this._addAzureAccountMessage, value: this._addAzureAccountMessage });
this._azureAccountDropdown.setOptions(accountDropdownOptions);
this._azureAccountDropdown.selectWithOptionName(oldSelection);
}

View File

@@ -1299,6 +1299,7 @@ suite('SQL ConnectionManagementService tests', () => {
let servername = 'test-database.database.windows.net';
azureConnectionProfile.serverName = servername;
let providerId = 'azure_PublicCloud';
azureConnectionProfile.azureTenantId = 'testTenant';
// Set up the account management service to return a token for the given user
accountManagementService.setup(x => x.getAccountsForProvider(TypeMoq.It.isAny())).returns(providerId => Promise.resolve<azdata.Account[]>([
@@ -1327,10 +1328,9 @@ suite('SQL ConnectionManagementService tests', () => {
]);
});
let testToken = 'testToken';
accountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
azure_publicCloud: {
token: testToken
}
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
token: testToken,
tokenType: 'Bearer'
}));
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({
profile: profile,
@@ -1384,11 +1384,8 @@ suite('SQL ConnectionManagementService tests', () => {
]);
});
let testToken = 'testToken';
let returnedTokens = {};
returnedTokens['azure_publicCloud'] = { token: 'badToken' };
returnedTokens[azureTenantId] = { token: testToken };
accountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(returnedTokens));
let returnedToken = { token: 'testToken', tokenType: 'Bearer' };
accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(returnedToken));
connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({
profile: profile,
savedCred: false
@@ -1399,7 +1396,7 @@ suite('SQL ConnectionManagementService tests', () => {
// Then the returned profile has the account token set corresponding to the requested tenant
assert.equal(profileWithCredentials.userName, azureConnectionProfile.userName);
assert.equal(profileWithCredentials.options['azureAccountToken'], testToken);
assert.equal(profileWithCredentials.options['azureAccountToken'], returnedToken.token);
});
test('getConnections test', () => {

View File

@@ -60,7 +60,12 @@ export class FirewallRuleDialogController {
private async handleOnCreateFirewallRule(): Promise<void> {
const resourceProviderId = this._resourceProviderId;
try {
const securityTokenMappings = await this._accountManagementService.getSecurityToken(this._firewallRuleDialog.viewModel.selectedAccount!, AzureResource.ResourceManagement);
const tenantId = this._connection.azureTenantId;
const token = await this._accountManagementService.getAccountSecurityToken(this._firewallRuleDialog.viewModel.selectedAccount!, tenantId, AzureResource.ResourceManagement);
const securityTokenMappings = {
[tenantId]: token
};
const firewallRuleInfo: azdata.FirewallRuleInfo = {
startIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.fromSubnetIPRange,
endIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.toSubnetIPRange,

View File

@@ -92,7 +92,8 @@ suite('Firewall rule dialog controller tests', () => {
providerName: mssqlProviderName,
options: {},
saveProfile: true,
id: ''
id: '',
azureTenantId: 'someTenant'
};
});
@@ -137,7 +138,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and call create firewall rule in resource provider
await deferredPromise;
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockFirewallRuleDialog.verify(x => x.close(), TypeMoq.Times.once());
mockFirewallRuleDialog.verify(x => x.onServiceComplete(), TypeMoq.Times.once());
@@ -164,7 +165,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and an error dialog should have been opened
await deferredPromise;
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.never());
});
@@ -191,7 +192,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and an error dialog should have been opened
await deferredPromise;
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
@@ -217,7 +218,7 @@ suite('Firewall rule dialog controller tests', () => {
// Then: it should get security token from account management service and an error dialog should have been opened
await deferredPromise;
mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
@@ -226,8 +227,8 @@ suite('Firewall rule dialog controller tests', () => {
function getMockAccountManagementService(resolveSecurityToken: boolean): TypeMoq.Mock<TestAccountManagementService> {
let accountManagementTestService = new TestAccountManagementService();
let mockAccountManagementService = TypeMoq.Mock.ofInstance(accountManagementTestService);
mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()))
.returns(() => resolveSecurityToken ? Promise.resolve({}) : Promise.reject(null));
mockAccountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()))
.returns(() => resolveSecurityToken ? Promise.resolve({ token: 'token' }) : Promise.reject(null));
return mockAccountManagementService;
}

View File

@@ -442,6 +442,8 @@ function getMockAccountManagementService(accounts: azdata.Account[]): TypeMoq.Mo
.returns(() => Promise.resolve(accounts));
mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isValue(accounts[0]), TypeMoq.It.isAny()))
.returns(() => Promise.resolve({}));
mockAccountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isValue(accounts[0]), TypeMoq.It.isAny(), TypeMoq.It.isAny()))
.returns(() => Promise.resolve(undefined));
mockAccountManagementService.setup(x => x.updateAccountListEvent)
.returns(() => () => { return undefined; });