diff --git a/extensions/azurecore/src/account-provider/auths/azureAuth.ts b/extensions/azurecore/src/account-provider/auths/azureAuth.ts index 90ca6d1471..f2c2cf1223 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuth.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuth.ts @@ -3,104 +3,39 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import * as azdata from 'azdata'; import * as vscode from 'vscode'; +import * as azdata from 'azdata'; + import * as nls from 'vscode-nls'; -import axios, { AxiosResponse, AxiosRequestConfig } from 'axios'; -import * as qs from 'qs'; -import * as url from 'url'; import { - AzureAccountProviderMetadata, - Tenant, AzureAccount, - Resource, + AzureAccountProviderMetadata, AzureAuthType, - Subscription, - Deferred + Deferred, + Resource, + Tenant } from '../interfaces'; +import * as url from 'url'; import { SimpleTokenCache } from '../simpleTokenCache'; import { MemoryDatabase } from '../utils/memoryDatabase'; +import axios, { AxiosRequestConfig, AxiosResponse } from 'axios'; import { Logger } from '../../utils/Logger'; +import * as qs from 'qs'; +import { AzureAuthError } from './azureAuthError'; + const localize = nls.loadMessageBundle(); -export interface AccountKey { - /** - * Account Key - uniquely identifies an account - */ - key: string -} - -export interface AccessToken extends AccountKey { - /** - * Access Token - */ - token: string; -} - -export interface RefreshToken extends AccountKey { - /** - * Refresh Token - */ - token: string; - - /** - * Account Key - */ - key: string -} - -export interface TokenResponse { - [tenantId: string]: Token -} - -export interface Token extends AccountKey { - /** - * Access token - */ - token: string; - - /** - * TokenType - */ - tokenType: string; -} - -export interface TokenClaims { // https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens - aud: string; - iss: string; - iat: number; - idp: string, - nbf: number; - exp: number; - c_hash: string; - at_hash: string; - aio: string; - preferred_username: string; - email: string; - name: string; - nonce: string; - oid: string; - roles: string[]; - rh: string; - sub: string; - tid: string; - unique_name: string; - uti: string; - ver: string; -} - -export type TokenRefreshResponse = { accessToken: AccessToken, refreshToken: RefreshToken, tokenClaims: TokenClaims, expiresOn: string }; export abstract class AzureAuth implements vscode.Disposable { - protected readonly memdb = new MemoryDatabase(); + protected readonly memdb = new MemoryDatabase(); protected readonly WorkSchoolAccountType: string = 'work_school'; protected readonly MicrosoftAccountType: string = 'microsoft'; protected readonly loginEndpointUrl: string; - protected readonly commonTenant: Tenant; + public readonly commonTenant: Tenant; protected readonly redirectUri: string; protected readonly scopes: string[]; protected readonly scopesString: string; @@ -141,185 +76,227 @@ export abstract class AzureAuth implements vscode.Disposable { this.scopesString = this.scopes.join(' '); } - public abstract async login(): Promise; - - public abstract async autoOAuthCancelled(): Promise; - - public abstract async promptForConsent(resourceId: string, tenant: string): Promise<{ tokenRefreshResponse: TokenRefreshResponse, authCompleteDeferred: Deferred } | undefined>; - - public dispose() { } - - public async refreshAccess(oldAccount: azdata.Account): Promise { - const response = await this.getCachedToken(oldAccount.key); - if (!response) { - oldAccount.isStale = true; - return oldAccount; - } - - const refreshToken = response.refreshToken; - if (!refreshToken || !refreshToken.key) { - oldAccount.isStale = true; - return oldAccount; - } - + public async startLogin(): Promise { + let loginComplete: Deferred; try { - // Refresh the access token - const tokenResponse = await this.refreshAccessToken(oldAccount.key, refreshToken); - const tenants = await this.getTenants(tokenResponse.accessToken); - - // Recreate account object - const newAccount = this.createAccount(tokenResponse.tokenClaims, tokenResponse.accessToken.key, tenants); - - const subscriptions = await this.getSubscriptions(newAccount); - newAccount.properties.subscriptions = subscriptions; - - return newAccount; + const result = await this.login(this.commonTenant, this.metadata.settings.microsoftResource); + loginComplete = result.authComplete; + if (!result?.response) { + Logger.error('Authentication failed'); + return { + canceled: false + }; + } + const account = await this.hydrateAccount(result.response.accessToken, result.response.tokenClaims); + loginComplete?.resolve(); + return account; } catch (ex) { - oldAccount.isStale = true; - if (ex.message) { - await vscode.window.showErrorMessage(ex.message); + if (ex instanceof AzureAuthError) { + if (loginComplete) { + loginComplete.reject(ex.getPrintableString()); + } else { + vscode.window.showErrorMessage(ex.getPrintableString()); + } } Logger.error(ex); + return undefined; } - return oldAccount; } + private getHomeTenant(account: AzureAccount): Tenant { + // Home is defined by the API + // Lets pick the home tenant - and fall back to commonTenant if they don't exist + return account.properties.tenants.find(t => t.tenantCategory === 'Home') ?? account.properties.tenants[0] ?? this.commonTenant; + } - public async getSecurityToken(account: azdata.Account, azureResource: azdata.AzureResource): Promise { + public async refreshAccess(account: AzureAccount): Promise { + try { + const tenant = this.getHomeTenant(account); + const tokenResult = await this.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.MicrosoftResourceManagement); + if (!tokenResult) { + account.isStale = true; + return account; + } + + return await this.hydrateAccount(tokenResult, this.getTokenClaims(tokenResult.token)); + } catch (ex) { + if (ex instanceof AzureAuthError) { + vscode.window.showErrorMessage(ex.getPrintableString()); + } + Logger.error(ex); + account.isStale = true; + return account; + } + } + + public async hydrateAccount(token: Token | AccessToken, tokenClaims: TokenClaims): Promise { + const tenants = await this.getTenants({ ...token }); + const account = this.createAccount(tokenClaims, token.key, tenants); + return account; + } + + public async getAccountSecurityToken(account: AzureAccount, tenantId: string, azureResource: azdata.AzureResource): Promise { if (account.isStale === true) { - Logger.log('Account was stale, no tokens being fetched'); + Logger.log('Account was stale. No tokens being fetched.'); return undefined; } const resource = this.resources.find(s => s.azureResourceId === azureResource); if (!resource) { + Logger.log('Invalid resource, not fetching', azureResource); + return undefined; } - const azureAccount = account as AzureAccount; - const response: TokenResponse = {}; + const tenant = account.properties.tenants.find(t => t.id === tenantId); - for (const tenant of azureAccount.properties.tenants) { - let cachedTokens = await this.getCachedToken(account.key, resource.id, tenant.id); - // Check expiration - if (cachedTokens) { - const expiresOn = Number(this.memdb.get(this.createMemdbString(account.key.accountId, tenant.id, resource.id))); - const currentTime = new Date().getTime() / 1000; + if (!tenant) { + throw new AzureAuthError(localize('azure.tenantNotFound', "Specifed tenant with ID '{0}' not found.", tenantId), `Tenant ${tenantId} not found.`, undefined); + } - if (!Number.isNaN(expiresOn)) { - const remainingTime = expiresOn - currentTime; - const fiveMinutes = 5 * 60; - // If the remaining time is less than five minutes, assume the token has expired. It's too close to expiration to be meaningful. - if (remainingTime < fiveMinutes) { - cachedTokens = undefined; - } - } else { - // No expiration date, assume expired. - cachedTokens = undefined; - Logger.log('Assuming expired token due to no expiration date - this is expected on first launch.'); - } + const cachedTokens = await this.getSavedToken(tenant, resource, account.key); + // Let's check to see if we can just use the cached tokens to return to the user + if (cachedTokens?.accessToken) { + let expiry = Number(cachedTokens.expiresOn); + if (Number.isNaN(expiry)) { + Logger.log('Expiration time was not defined. This is expected on first launch'); + expiry = 0; } + const currentTime = new Date().getTime() / 1000; - // Refresh - if (!cachedTokens) { + let accessToken = cachedTokens.accessToken; + const remainingTime = expiry - currentTime; + const maxTolerance = 2 * 60; // two minutes - const baseToken = await this.getCachedToken(account.key); - if (!baseToken) { - account.isStale = true; - Logger.log('Base token was empty, account is stale.'); - return undefined; - } - - try { - await this.refreshAccessToken(account.key, baseToken.refreshToken, tenant, resource); - } catch (ex) { - Logger.log(`Could not refresh access token for ${JSON.stringify(tenant)} - silently removing the tenant from the user's account.`); - Logger.error(`Actual error: ${JSON.stringify(ex?.response?.data ?? ex.message ?? ex, undefined, 2)}`); - azureAccount.properties.tenants = azureAccount.properties.tenants.filter(t => t.id !== tenant.id); - continue; - } - - cachedTokens = await this.getCachedToken(account.key, resource.id, tenant.id); - if (!cachedTokens) { - Logger.log('Refresh access tokens didn not set cache'); - return undefined; - } + if (remainingTime < maxTolerance) { + const result = await this.refreshToken(tenant, resource, cachedTokens.refreshToken); + accessToken = result.accessToken; } - const { accessToken } = cachedTokens; - response[tenant.id] = { - token: accessToken.token, - key: accessToken.key, + // Let's just return here. + if (accessToken) { + return { + ...accessToken, + tokenType: 'Bearer' + }; + } + } + + // User didn't have any cached tokens, or the cached tokens weren't useful. + // For most users we can use the refresh token from the general microsoft resource to an access token of basically any type of resource we want. + const baseTokens = await this.getSavedToken(this.commonTenant, this.metadata.settings.microsoftResource, account.key); + if (!baseTokens) { + Logger.error('User had no base tokens for the basic resource registered. This should not happen and indicates something went wrong with the authentication cycle'); + const msg = localize('azure.noBaseToken', 'Something failed with the authentication, or your tokens have been deleted from the system. Please try adding your account to Azure Data Studio again.'); + account.isStale = true; + throw new AzureAuthError(msg, 'No base token found', undefined); + } + // Let's try to convert the access token type, worst case we'll have to prompt the user to do an interactive authentication. + const result = await this.refreshToken(tenant, resource, baseTokens.refreshToken); + if (result.accessToken) { + return { + ...result.accessToken, tokenType: 'Bearer' }; } - - if (azureAccount.properties.subscriptions) { - azureAccount.properties.subscriptions.forEach(subscription => { - // Make sure that tenant has information populated. - if (response[subscription.tenantId]) { - response[subscription.id] = { - ...response[subscription.tenantId] - }; - } - }); - } - - return response; + return undefined; } - public async clearCredentials(account: azdata.AccountKey): Promise { - try { - return this.deleteAccountCache(account); - } catch (ex) { - const msg = localize('azure.cacheErrrorRemove', "Error when removing your account from the cache."); - vscode.window.showErrorMessage(msg); - Logger.error('Error when removing tokens.', ex); - } - } - protected toBase64UrlEncoding(base64string: string): string { - return base64string.replace(/=/g, '').replace(/\+/g, '-').replace(/\//g, '_'); // Need to use base64url encoding - } - protected async makePostRequest(uri: string, postData: { [key: string]: string }, validateStatus = false) { - try { - const config: AxiosRequestConfig = { - headers: { - 'Content-Type': 'application/x-www-form-urlencoded' - } + protected abstract async login(tenant: Tenant, resource: Resource): Promise<{ response: OAuthTokenResponse, authComplete: Deferred }>; + + /** + * Refreshes a token, if a refreshToken is passed in then we use that. If it is not passed in then we will prompt the user for consent. + * @param tenant + * @param resource + * @param refreshToken + */ + public async refreshToken(tenant: Tenant, resource: Resource, refreshToken: RefreshToken | undefined): Promise { + if (refreshToken) { + const postData: RefreshTokenPostData = { + grant_type: 'refresh_token', + client_id: this.clientId, + refresh_token: refreshToken.token, + tenant: tenant.id, + resource: resource.endpoint }; - if (validateStatus) { - config.validateStatus = () => true; - } - - return await axios.post(uri, qs.stringify(postData), config); - } catch (ex) { - Logger.log('Unexpected error making Azure auth request', 'azureCore.postRequest', JSON.stringify(ex?.response?.data, undefined, 2)); - throw ex; + return this.getToken(tenant, resource, postData); } + + return this.handleInteractionRequired(tenant, resource); } - protected async makeGetRequest(token: string, uri: string): Promise> { - try { - const config = { - headers: { - Authorization: `Bearer ${token}`, - 'Content-Type': 'application/json', - }, + public async getToken(tenant: Tenant, resource: Resource, postData: AuthorizationCodePostData | TokenPostData | RefreshTokenPostData): Promise { + const tokenUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/token`; + const response = await this.makePostRequest(tokenUrl, postData); + + if (response.data.error === 'interaction_required') { + return this.handleInteractionRequired(tenant, resource); + } + + if (response.data.error) { + Logger.error('Response error!', response.data); + throw new AzureAuthError(localize('azure.responseError', "Token retrival failed with an error. Open developer tools to view the error"), 'Token retrival failed', undefined); + } + + const accessTokenString = response.data.access_token; + const refreshTokenString = response.data.refresh_token; + const expiresOnString = response.data.expires_on; + + return this.getTokenHelper(tenant, resource, accessTokenString, refreshTokenString, expiresOnString); + } + + public async getTokenHelper(tenant: Tenant, resource: Resource, accessTokenString: string, refreshTokenString: string, expiresOnString: string): Promise { + if (!accessTokenString) { + const msg = localize('azure.accessTokenEmpty', 'No access token returned from Microsoft OAuth'); + throw new AzureAuthError(msg, 'Access token was empty', undefined); + } + + const tokenClaims: TokenClaims = this.getTokenClaims(accessTokenString); + + const userKey = tokenClaims.sub ?? tokenClaims.oid; + + if (!userKey) { + const msg = localize('azure.noUniqueIdentifier', "The user had no unique identifier within AAD"); + throw new AzureAuthError(msg, 'No unique identifier', undefined); + } + + const accessToken: AccessToken = { + token: accessTokenString, + key: userKey + }; + let refreshToken: RefreshToken; + + if (refreshTokenString) { + refreshToken = { + token: refreshTokenString, + key: userKey }; - - return await axios.get(uri, config); - } catch (ex) { - // Intercept and print error - Logger.log('Unexpected error making Azure auth request', 'azureCore.getRequest', JSON.stringify(ex?.response?.data, undefined, 2)); - // rethrow error - throw ex; } + + const result: OAuthTokenResponse = { + accessToken, + refreshToken, + tokenClaims, + expiresOn: expiresOnString + }; + + const accountKey: azdata.AccountKey = { + providerId: this.metadata.id, + accountId: userKey + }; + + await this.saveToken(tenant, resource, accountKey, result); + + return result; } - protected async getTenants(token: AccessToken): Promise { + + + //#region tenant calls + public async getTenants(token: AccessToken): Promise { interface TenantResponse { // https://docs.microsoft.com/en-us/rest/api/resources/tenants/list id: string tenantId: string @@ -329,7 +306,7 @@ export abstract class AzureAuth implements vscode.Disposable { const tenantUri = url.resolve(this.metadata.settings.armResource.endpoint, 'tenants?api-version=2019-11-01'); try { - const tenantResponse = await this.makeGetRequest(token.token, tenantUri); + const tenantResponse = await this.makeGetRequest(tenantUri, token.token); Logger.pii('getTenants', tenantResponse.data); const tenants: Tenant[] = tenantResponse.data.value.map((tenantInfo: TenantResponse) => { return { @@ -353,97 +330,88 @@ export abstract class AzureAuth implements vscode.Disposable { } } - protected async getSubscriptions(account: AzureAccount): Promise { - interface SubscriptionResponse { // https://docs.microsoft.com/en-us/rest/api/resources/subscriptions/list - subscriptionId: string - tenantId: string - displayName: string - } - const allSubs: Subscription[] = []; - const tokens = await this.getSecurityToken(account, azdata.AzureResource.ResourceManagement); - if (!tokens) { - Logger.log('There were no resource management tokens to retrieve subscriptions from. Account is stale.'); - account.isStale = true; - } + //#endregion - for (const tenant of account.properties.tenants) { - const token = tokens[tenant.id]; - const subscriptionUri = url.resolve(this.metadata.settings.armResource.endpoint, 'subscriptions?api-version=2019-11-01'); - try { - const subscriptionResponse = await this.makeGetRequest(token.token, subscriptionUri); - Logger.pii('getSubscriptions', subscriptionResponse.data); - const subscriptions: Subscription[] = subscriptionResponse.data.value.map((subscriptionInfo: SubscriptionResponse) => { - return { - id: subscriptionInfo.subscriptionId, - displayName: subscriptionInfo.displayName, - tenantId: subscriptionInfo.tenantId - } as Subscription; - }); - allSubs.push(...subscriptions); - } catch (ex) { - Logger.error(ex); - throw new Error('Error retrieving subscription information'); - } + //#region token management + private async saveToken(tenant: Tenant, resource: Resource, accountKey: azdata.AccountKey, { accessToken, refreshToken, expiresOn }: OAuthTokenResponse) { + const msg = localize('azure.cacheErrorAdd', "Error when adding your account to the cache."); + if (!tenant.id || !resource.id) { + Logger.pii('Tenant ID or resource ID was undefined', tenant, resource); + throw new AzureAuthError(msg, 'Adding account to cache failed', undefined); } - return allSubs; - } - - protected async getToken(postData: { [key: string]: string }, tenant = this.commonTenant, resourceId: string = '', resourceEndpoint: string = ''): Promise { try { - let refreshResponse: TokenRefreshResponse; - - try { - const tokenUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/token`; - const tokenResponse = await this.makePostRequest(tokenUrl, postData); - Logger.pii(JSON.stringify(tokenResponse.data)); - const tokenClaims = this.getTokenClaims(tokenResponse.data.access_token); - - const accessToken: AccessToken = { - token: tokenResponse.data.access_token, - key: tokenClaims.oid ?? tokenClaims.email ?? tokenClaims.unique_name ?? tokenClaims.name, - }; - - const refreshToken: RefreshToken = { - token: tokenResponse.data.refresh_token, - key: accessToken.key - }; - const expiresOn = tokenResponse.data.expires_on; - - refreshResponse = { accessToken, refreshToken, tokenClaims, expiresOn }; - } catch (ex) { - Logger.pii(JSON.stringify(ex?.response?.data)); - if (ex?.response?.data?.error === 'interaction_required') { - const shouldOpenLink = await this.openConsentDialog(tenant, resourceId); - if (shouldOpenLink === true) { - const { tokenRefreshResponse, authCompleteDeferred } = await this.promptForConsent(resourceEndpoint, tenant.id); - refreshResponse = tokenRefreshResponse; - authCompleteDeferred.resolve(); - } else { - vscode.window.showInformationMessage(localize('azure.noConsentToReauth', "The authentication failed since Azure Data Studio was unable to open re-authentication page.")); - } - } else { - return undefined; - } - } - - this.memdb.set(this.createMemdbString(refreshResponse.accessToken.key, tenant.id, resourceId), refreshResponse.expiresOn); - return refreshResponse; - } catch (err) { - const msg = localize('azure.noToken', "Retrieving the Azure token failed. Please sign in again."); - vscode.window.showErrorMessage(msg); - throw new Error(err); + await this.tokenCache.saveCredential(`${accountKey.accountId}_access_${resource.id}_${tenant.id}`, JSON.stringify(accessToken)); + await this.tokenCache.saveCredential(`${accountKey.accountId}_refresh_${resource.id}_${tenant.id}`, JSON.stringify(refreshToken)); + this.memdb.set(`${accountKey.accountId}_${tenant.id}_${resource.id}`, expiresOn); + } catch (ex) { + Logger.error(ex); + throw new AzureAuthError(msg, 'Adding account to cache failed', ex); } } - public async openConsentDialog(tenant: Tenant, resourceId: string): Promise { + public async getSavedToken(tenant: Tenant, resource: Resource, accountKey: azdata.AccountKey): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> { + const getMsg = localize('azure.cacheErrorGet', "Error when getting your account from the cache"); + const parseMsg = localize('azure.cacheErrorParse', "Error when parsing your account from the cache"); + + if (!tenant.id || !resource.id) { + Logger.pii('Tenant ID or resource ID was undefined', tenant, resource); + throw new AzureAuthError(getMsg, 'Getting account from cache failed', undefined); + } + + let accessTokenString: string; + let refreshTokenString: string; + let expiresOn: string; + try { + accessTokenString = await this.tokenCache.getCredential(`${accountKey.accountId}_access_${resource.id}_${tenant.id}`); + refreshTokenString = await this.tokenCache.getCredential(`${accountKey.accountId}_refresh_${resource.id}_${tenant.id}`); + expiresOn = this.memdb.get(`${accountKey.accountId}_${tenant.id}_${resource.id}`); + } catch (ex) { + Logger.error(ex); + throw new AzureAuthError(getMsg, 'Getting account from cache failed', ex); + } + + try { + if (!accessTokenString) { + return undefined; + } + const accessToken: AccessToken = JSON.parse(accessTokenString); + let refreshToken: RefreshToken; + if (refreshTokenString) { + refreshToken = JSON.parse(refreshTokenString); + } + + return { + accessToken, refreshToken, expiresOn + }; + } catch (ex) { + Logger.error(ex); + throw new AzureAuthError(parseMsg, 'Parsing account from cache failed', ex); + } + } + //#endregion + + //#region interaction handling + + public async handleInteractionRequired(tenant: Tenant, resource: Resource): Promise { + const shouldOpen = await this.askUserForInteraction(tenant, resource); + if (shouldOpen) { + const result = await this.login(tenant, resource); + result?.authComplete?.resolve(); + return result?.response; + } + return undefined; + } + + /** + * Asks the user if they would like to do the interaction based authentication as required by OAuth2 + * @param tenant + * @param resource + */ + private async askUserForInteraction(tenant: Tenant, resource: Resource): Promise { if (!tenant.displayName && !tenant.id) { throw new Error('Tenant did not have display name or id'); } - if (tenant.id === 'common') { - throw new Error('Common tenant should not need consent'); - } - const getTenantConfigurationSet = (): Set => { const configuration = vscode.workspace.getConfiguration('azure.tenant.config'); let values: string[] = configuration.get('filter') ?? []; @@ -486,7 +454,7 @@ export abstract class AzureAuth implements vscode.Disposable { } }; - const messageBody = localize('azurecore.consentDialog.body', "Your tenant '{0} ({1})' requires you to re-authenticate again to access {2} resources. Press Open to start the authentication process.", tenant.displayName, tenant.id, resourceId); + const messageBody = localize('azurecore.consentDialog.body', "Your tenant '{0} ({1})' requires you to re-authenticate again to access {2} resources. Press Open to start the authentication process.", tenant.displayName, tenant.id, resource.id); const result = await vscode.window.showInformationMessage(messageBody, { modal: true }, openItem, closeItem, dontAskAgainItem); if (result.action) { @@ -495,113 +463,9 @@ export abstract class AzureAuth implements vscode.Disposable { return result.booleanResult; } + //#endregion - protected getTokenClaims(accessToken: string): TokenClaims | undefined { - try { - const split = accessToken.split('.'); - return JSON.parse(Buffer.from(split[1], 'base64').toString('binary')); - } catch (ex) { - throw new Error('Unable to read token claims: ' + JSON.stringify(ex)); - } - } - - private async refreshAccessToken(account: azdata.AccountKey, rt: RefreshToken, tenant: Tenant = this.commonTenant, resource?: Resource): Promise { - const postData: { [key: string]: string } = { - grant_type: 'refresh_token', - refresh_token: rt.token, - client_id: this.clientId, - tenant: tenant.id, - }; - - if (resource) { - postData.resource = resource.endpoint; - } - - const getTokenResponse = await this.getToken(postData, tenant, resource?.id, resource?.endpoint); - - const accessToken = getTokenResponse?.accessToken; - const refreshToken = getTokenResponse?.refreshToken; - - if (!accessToken || !refreshToken) { - Logger.log('Access or refresh token were undefined'); - const msg = localize('azure.refreshTokenError', "Error when refreshing your account."); - throw new Error(msg); - } - - await this.setCachedToken(account, accessToken, refreshToken, resource?.id, tenant?.id); - - return getTokenResponse; - } - - - public async setCachedToken(account: azdata.AccountKey, accessToken: AccessToken, refreshToken: RefreshToken, resourceId?: string, tenantId?: string): Promise { - const msg = localize('azure.cacheErrorAdd', "Error when adding your account to the cache."); - resourceId = resourceId ?? ''; - tenantId = tenantId ?? ''; - if (!accessToken || !accessToken.token || !refreshToken.token || !accessToken.key) { - throw new Error(msg); - } - - try { - await this.tokenCache.saveCredential(`${account.accountId}_access_${resourceId}_${tenantId}`, JSON.stringify(accessToken)); - await this.tokenCache.saveCredential(`${account.accountId}_refresh_${resourceId}_${tenantId}`, JSON.stringify(refreshToken)); - } catch (ex) { - Logger.error('Error when storing tokens.', ex); - throw new Error(msg); - } - } - - public async getCachedToken(account: azdata.AccountKey, resourceId?: string, tenantId?: string): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken } | undefined> { - resourceId = resourceId ?? ''; - tenantId = tenantId ?? ''; - - let accessToken: AccessToken; - let refreshToken: RefreshToken; - try { - accessToken = JSON.parse(await this.tokenCache.getCredential(`${account.accountId}_access_${resourceId}_${tenantId}`)); - refreshToken = JSON.parse(await this.tokenCache.getCredential(`${account.accountId}_refresh_${resourceId}_${tenantId}`)); - } catch (ex) { - return undefined; - } - - if (!accessToken || !refreshToken) { - return undefined; - } - - if (!refreshToken.token || !refreshToken.key) { - return undefined; - } - - if (!accessToken.token || !accessToken.key) { - return undefined; - } - - return { - accessToken, - refreshToken - }; - - } - - public createMemdbString(accountKey: string, tenantId: string, resourceId: string): string { - return `${accountKey}_${tenantId}_${resourceId}`; - } - - public async deleteAccountCache(account: azdata.AccountKey): Promise { - const results = await this.tokenCache.findCredentials(account.accountId); - - for (let { account } of results) { - await this.tokenCache.clearCredential(account); - } - } - - public async deleteAllCache(): Promise { - const results = await this.tokenCache.findCredentials(''); - - for (let { account } of results) { - await this.tokenCache.clearCredential(account); - } - } + //#region data modeling public createAccount(tokenClaims: TokenClaims, key: string, tenants: Tenant[]): AzureAccount { // Determine if this is a microsoft account @@ -663,4 +527,184 @@ export abstract class AzureAuth implements vscode.Disposable { return account; } + + //#endregion + + //#region network functions + public async makePostRequest(url: string, postData: AuthorizationCodePostData | TokenPostData | DeviceCodeStartPostData | DeviceCodeCheckPostData): Promise> { + const config: AxiosRequestConfig = { + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + validateStatus: () => true // Never throw + }; + + // Intercept response and print out the response for future debugging + const response = await axios.post(url, qs.stringify(postData), config); + Logger.pii(url, postData, response.data); + return response; + } + + private async makeGetRequest(url: string, token: string): Promise> { + const config: AxiosRequestConfig = { + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}` + }, + validateStatus: () => true // Never throw + }; + + const response = await axios.get(url, config); + Logger.pii(url, response.data); + return response; + } + + //#endregion + + //#region inconsequential + protected getTokenClaims(accessToken: string): TokenClaims | undefined { + try { + const split = accessToken.split('.'); + return JSON.parse(Buffer.from(split[1], 'base64').toString('binary')); + } catch (ex) { + throw new Error('Unable to read token claims: ' + JSON.stringify(ex)); + } + } + + protected toBase64UrlEncoding(base64string: string): string { + return base64string.replace(/=/g, '').replace(/\+/g, '-').replace(/\//g, '_'); // Need to use base64url encoding + } + + public async deleteAllCache(): Promise { + const results = await this.tokenCache.findCredentials(''); + + for (let { account } of results) { + await this.tokenCache.clearCredential(account); + } + } + + public async clearCredentials(account: azdata.AccountKey): Promise { + try { + return this.deleteAccountCache(account); + } catch (ex) { + const msg = localize('azure.cacheErrrorRemove', "Error when removing your account from the cache."); + vscode.window.showErrorMessage(msg); + Logger.error('Error when removing tokens.', ex); + } + } + + public async deleteAccountCache(account: azdata.AccountKey): Promise { + const results = await this.tokenCache.findCredentials(account.accountId); + + for (let { account } of results) { + await this.tokenCache.clearCredential(account); + } + } + + public async dispose() { } + + public async autoOAuthCancelled(): Promise { } + + //#endregion } + +//#region models + +export interface AccountKey { + /** + * Account Key - uniquely identifies an account + */ + key: string +} + +export interface AccessToken extends AccountKey { + /** + * Access Token + */ + token: string; +} + +export interface RefreshToken extends AccountKey { + /** + * Refresh Token + */ + token: string; + + /** + * Account Key + */ + key: string +} + +export interface MultiTenantTokenResponse { + [tenantId: string]: Token +} + +export interface Token extends AccountKey { + /** + * Access token + */ + token: string; + + /** + * TokenType + */ + tokenType: string; +} + +export interface TokenClaims { // https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens + aud: string; + iss: string; + iat: number; + idp: string, + nbf: number; + exp: number; + c_hash: string; + at_hash: string; + aio: string; + preferred_username: string; + email: string; + name: string; + nonce: string; + oid: string; + roles: string[]; + rh: string; + sub: string; + tid: string; + unique_name: string; + uti: string; + ver: string; +} + +export type OAuthTokenResponse = { accessToken: AccessToken, refreshToken: RefreshToken, tokenClaims: TokenClaims, expiresOn: string }; + +export interface TokenPostData { + grant_type: 'refresh_token' | 'authorization_code' | 'urn:ietf:params:oauth:grant-type:device_code'; + client_id: string; + resource: string; +} + +export interface RefreshTokenPostData extends TokenPostData { + grant_type: 'refresh_token'; + refresh_token: string; + client_id: string; + tenant: string +} + +export interface AuthorizationCodePostData extends TokenPostData { + grant_type: 'authorization_code'; + code: string; + code_verifier: string; + redirect_uri: string; +} + +export interface DeviceCodeStartPostData extends Omit { + +} + +export interface DeviceCodeCheckPostData extends Omit { + grant_type: 'urn:ietf:params:oauth:grant-type:device_code', + tenant: string, + code: string +} +//#endregion diff --git a/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts b/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts index 90e106cca8..e0138f5154 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuthCodeGrant.ts @@ -3,49 +3,37 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import * as azdata from 'azdata'; +import { AuthorizationCodePostData, AzureAuth, OAuthTokenResponse } from './azureAuth'; +import { AzureAccountProviderMetadata, AzureAuthType, Deferred, Resource, Tenant } from '../interfaces'; import * as vscode from 'vscode'; import * as crypto from 'crypto'; +import { SimpleTokenCache } from '../simpleTokenCache'; +import { SimpleWebServer } from '../utils/simpleWebServer'; +import { AzureAuthError } from './azureAuthError'; +import { Logger } from '../../utils/Logger'; import * as nls from 'vscode-nls'; -import { promises as fs } from 'fs'; import * as path from 'path'; import * as http from 'http'; import * as qs from 'qs'; +import { promises as fs } from 'fs'; -import { - AzureAuth, - AccessToken, - RefreshToken, - TokenClaims, - TokenRefreshResponse, -} from './azureAuth'; - -import { - AzureAccountProviderMetadata, - AzureAuthType, - Deferred -} from '../interfaces'; - -import { SimpleWebServer } from '../utils/simpleWebServer'; -import { SimpleTokenCache } from '../simpleTokenCache'; -import { Logger } from '../../utils/Logger'; const localize = nls.loadMessageBundle(); -function parseQuery(uri: vscode.Uri) { - return uri.query.split('&').reduce((prev: any, current) => { - const queryString = current.split('='); - prev[queryString[0]] = queryString[1]; - return prev; - }, {}); +interface AuthCodeResponse { + authCode: string; + codeVerifier: string; + redirectUri: string; } -interface AuthCodeResponse { - authCode: string, - codeVerifier: string +interface CryptoValues { + nonce: string; + codeVerifier: string; + codeChallenge: string; } + export class AzureAuthCodeGrant extends AzureAuth { - private static readonly USER_FRIENDLY_NAME: string = localize('azure.azureAuthCodeGrantName', "Azure Auth Code Grant"); + private static readonly USER_FRIENDLY_NAME: string = localize('azure.azureAuthCodeGrantName', 'Azure Auth Code Grant'); private server: SimpleWebServer; constructor( @@ -57,105 +45,52 @@ export class AzureAuthCodeGrant extends AzureAuth { super(metadata, tokenCache, context, uriEventEmitter, AzureAuthType.AuthCodeGrant, AzureAuthCodeGrant.USER_FRIENDLY_NAME); } - public async promptForConsent(resourceEndpoint: string, tenant: string = this.commonTenant.id): Promise<{ tokenRefreshResponse: TokenRefreshResponse, authCompleteDeferred: Deferred } | undefined> { + + protected async login(tenant: Tenant, resource: Resource): Promise<{ response: OAuthTokenResponse, authComplete: Deferred }> { let authCompleteDeferred: Deferred; let authCompletePromise = new Promise((resolve, reject) => authCompleteDeferred = { resolve, reject }); - let authResponse: AuthCodeResponse; + if (vscode.env.uiKind === vscode.UIKind.Web) { - authResponse = await this.loginWithoutLocalServer(resourceEndpoint, tenant); + authResponse = await this.loginWeb(tenant, resource); } else { - authResponse = await this.loginWithLocalServer(authCompletePromise, resourceEndpoint, tenant); - } - - let tokenClaims: TokenClaims; - let accessToken: AccessToken; - let refreshToken: RefreshToken; - let expiresOn: string; - - try { - const { accessToken: at, refreshToken: rt, tokenClaims: tc, expiresOn: eo } = await this.getTokenWithAuthCode(authResponse.authCode, authResponse.codeVerifier, this.redirectUri); - tokenClaims = tc; - accessToken = at; - refreshToken = rt; - expiresOn = eo; - } catch (ex) { - if (ex.msg) { - vscode.window.showErrorMessage(ex.msg); - } - Logger.error(ex); - } - - if (!accessToken) { - const msg = localize('azure.tokenFail', "Failure when retrieving tokens."); - authCompleteDeferred.reject(new Error(msg)); - throw Error('Failure when retrieving tokens'); + authResponse = await this.loginDesktop(tenant, resource, authCompletePromise); } return { - tokenRefreshResponse: { accessToken, refreshToken, tokenClaims, expiresOn }, - authCompleteDeferred + response: await this.getTokenWithAuthorizationCode(tenant, resource, authResponse), + authComplete: authCompleteDeferred }; } - public async autoOAuthCancelled(): Promise { - return this.server.shutdown(); - } - - public async loginWithLocalServer(authCompletePromise: Promise, resourceId: string, tenant: string = this.commonTenant.id): Promise { - this.server = new SimpleWebServer(); - const nonce = crypto.randomBytes(16).toString('base64'); - let serverPort: string; - - try { - serverPort = await this.server.startup(); - } catch (err) { - const msg = localize('azure.serverCouldNotStart', 'Server could not start. This could be a permissions error or an incompatibility on your system. You can try enabling device code authentication from settings.'); - vscode.window.showErrorMessage(msg); - Logger.error(JSON.stringify(err)); - return undefined; - } - - // The login code to use - let loginUrl: string; - let codeVerifier: string; - { - codeVerifier = this.toBase64UrlEncoding(crypto.randomBytes(32).toString('base64')); - const state = `${serverPort},${encodeURIComponent(nonce)}`; - const codeChallenge = this.toBase64UrlEncoding(crypto.createHash('sha256').update(codeVerifier).digest('base64')); - const loginQuery = { - response_type: 'code', - response_mode: 'query', - client_id: this.clientId, - redirect_uri: this.redirectUri, - state, - prompt: 'select_account', - code_challenge_method: 'S256', - code_challenge: codeChallenge, - resource: resourceId - }; - loginUrl = `${this.loginEndpointUrl}${tenant}/oauth2/authorize?${qs.stringify(loginQuery)}`; - } - - await vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${serverPort}/signin?nonce=${encodeURIComponent(nonce)}`)); - - const authCode = await this.addServerListeners(this.server, nonce, loginUrl, authCompletePromise); - - return { - authCode, - codeVerifier + /** + * Requests an OAuthTokenResponse from Microsoft OAuth + * + * @param tenant + * @param resource + * @param authCode + * @param redirectUri + * @param codeVerifier + */ + private async getTokenWithAuthorizationCode(tenant: Tenant, resource: Resource, { authCode, redirectUri, codeVerifier }: AuthCodeResponse): Promise { + const postData: AuthorizationCodePostData = { + grant_type: 'authorization_code', + code: authCode, + client_id: this.clientId, + code_verifier: codeVerifier, + redirect_uri: redirectUri, + resource: resource.endpoint }; + + return this.getToken(tenant, resource, postData); } - public async loginWithoutLocalServer(resourceId: string, tenant: string = this.commonTenant.id): Promise { + private async loginWeb(tenant: Tenant, resource: Resource): Promise { const callbackUri = await vscode.env.asExternalUri(vscode.Uri.parse(`${vscode.env.uriScheme}://microsoft.azurecore`)); - const nonce = crypto.randomBytes(16).toString('base64'); + const { nonce, codeVerifier, codeChallenge } = this.createCryptoValues(); const port = (callbackUri.authority.match(/:([0-9]*)$/) || [])[1] || (callbackUri.scheme === 'https' ? 443 : 80); const state = `${port},${encodeURIComponent(nonce)},${encodeURIComponent(callbackUri.query)}`; - const codeVerifier = this.toBase64UrlEncoding(crypto.randomBytes(32).toString('base64')); - const codeChallenge = this.toBase64UrlEncoding(crypto.createHash('sha256').update(codeVerifier).digest('base64')); - const loginQuery = { response_type: 'code', response_mode: 'query', @@ -165,26 +100,27 @@ export class AzureAuthCodeGrant extends AzureAuth { prompt: 'select_account', code_challenge_method: 'S256', code_challenge: codeChallenge, - resource: resourceId + resource: resource.id }; const signInUrl = `${this.loginEndpointUrl}${tenant}/oauth2/authorize?${qs.stringify(loginQuery)}`; await vscode.env.openExternal(vscode.Uri.parse(signInUrl)); - const authCode = await this.handleCodeResponse(state); + const authCode = await this.handleWebResponse(state); return { authCode, - codeVerifier + codeVerifier, + redirectUri: this.redirectUri }; } - public async handleCodeResponse(state: string): Promise { + private async handleWebResponse(state: string): Promise { let uriEventListener: vscode.Disposable; return new Promise((resolve: (value: any) => void, reject) => { uriEventListener = this.uriEventEmitter.event(async (uri: vscode.Uri) => { try { - const query = parseQuery(uri); + const query = this.parseQuery(uri); const code = query.code; if (query.state !== state && decodeURIComponent(query.state) !== state) { reject(new Error('State mismatch')); @@ -200,33 +136,46 @@ export class AzureAuthCodeGrant extends AzureAuth { }); } - public async login(): Promise { - const { tokenRefreshResponse, authCompleteDeferred } = await this.promptForConsent(this.metadata.settings.signInResourceId); - const { accessToken, refreshToken, tokenClaims } = tokenRefreshResponse; + private parseQuery(uri: vscode.Uri): { [key: string]: string } { + return uri.query.split('&').reduce((prev: any, current) => { + const queryString = current.split('='); + prev[queryString[0]] = queryString[1]; + return prev; + }, {}); + } - const tenants = await this.getTenants(accessToken); + private async loginDesktop(tenant: Tenant, resource: Resource, authCompletePromise: Promise): Promise { + this.server = new SimpleWebServer(); + let serverPort: string; try { - await this.setCachedToken({ accountId: accessToken.key, providerId: this.metadata.id }, accessToken, refreshToken); + serverPort = await this.server.startup(); } catch (ex) { - Logger.error(ex); - if (ex.msg) { - vscode.window.showErrorMessage(ex.msg); - authCompleteDeferred.reject(ex); - } else { - authCompleteDeferred.reject(new Error('There was an issue when storing the cache.')); - } - - return { canceled: false } as azdata.PromptFailedResult; + const msg = localize('azure.serverCouldNotStart', 'Server could not start. This could be a permissions error or an incompatibility on your system. You can try enabling device code authentication from settings.'); + throw new AzureAuthError(msg, 'Server could not start', ex); } + const { nonce, codeVerifier, codeChallenge } = this.createCryptoValues(); + const state = `${serverPort},${encodeURIComponent(nonce)}`; + const loginQuery = { + response_type: 'code', + response_mode: 'query', + client_id: this.clientId, + redirect_uri: this.redirectUri, + state, + prompt: 'select_account', + code_challenge_method: 'S256', + code_challenge: codeChallenge, + resource: resource.endpoint + }; + const loginUrl = `${this.loginEndpointUrl}${tenant.id}/oauth2/authorize?${qs.stringify(loginQuery)}`; + await vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${serverPort}/signin?nonce=${encodeURIComponent(nonce)}`)); + const authCode = await this.addServerListeners(this.server, nonce, loginUrl, authCompletePromise); + return { + authCode, + codeVerifier, + redirectUri: this.redirectUri + }; - const account = this.createAccount(tokenClaims, accessToken.key, tenants); - - const subscriptions = await this.getSubscriptions(account); - account.properties.subscriptions = subscriptions; - - authCompleteDeferred.resolve(); - return account; } private async addServerListeners(server: SimpleWebServer, nonce: string, loginUrl: string, authComplete: Promise): Promise { @@ -266,7 +215,7 @@ export class AzureAuthCodeGrant extends AzureAuth { if (receivedNonce !== nonce) { res.writeHead(400, { 'content-type': 'text/html' }); - res.write(localize('azureAuth.nonceError', "Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again.")); + res.write(localize('azureAuth.nonceError', 'Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again.')); res.end(); Logger.error('nonce no match', receivedNonce, nonce); return; @@ -283,7 +232,7 @@ export class AzureAuthCodeGrant extends AzureAuth { const stateSplit = state.split(','); if (stateSplit.length !== 2) { res.writeHead(400, { 'content-type': 'text/html' }); - res.write(localize('azureAuth.stateError', "Authentication failed due to a state mismatch, please close ADS and try again.")); + res.write(localize('azureAuth.stateError', 'Authentication failed due to a state mismatch, please close ADS and try again.')); res.end(); reject(new Error('State mismatch')); return; @@ -291,7 +240,7 @@ export class AzureAuthCodeGrant extends AzureAuth { if (stateSplit[1] !== encodeURIComponent(nonce)) { res.writeHead(400, { 'content-type': 'text/html' }); - res.write(localize('azureAuth.nonceError', "Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again.")); + res.write(localize('azureAuth.nonceError', 'Authentication failed due to a nonce mismatch, please close Azure Data Studio and try again.')); res.end(); reject(new Error('Nonce mismatch')); return; @@ -310,20 +259,14 @@ export class AzureAuthCodeGrant extends AzureAuth { }); } - private async getTokenWithAuthCode(authCode: string, codeVerifier: string, redirectUri: string): Promise { - const postData = { - grant_type: 'authorization_code', - code: authCode, - client_id: this.clientId, - code_verifier: codeVerifier, - redirect_uri: redirectUri, - resource: this.metadata.settings.signInResourceId + + private createCryptoValues(): CryptoValues { + const nonce = crypto.randomBytes(16).toString('base64'); + const codeVerifier = this.toBase64UrlEncoding(crypto.randomBytes(32).toString('base64')); + const codeChallenge = this.toBase64UrlEncoding(crypto.createHash('sha256').update(codeVerifier).digest('base64')); + + return { + nonce, codeVerifier, codeChallenge }; - - return this.getToken(postData); - } - - public dispose() { - this.server?.shutdown().catch(console.error); } } diff --git a/extensions/azurecore/src/account-provider/auths/azureAuthError.ts b/extensions/azurecore/src/account-provider/auths/azureAuthError.ts new file mode 100644 index 0000000000..5234bbba5f --- /dev/null +++ b/extensions/azurecore/src/account-provider/auths/azureAuthError.ts @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +export class AzureAuthError extends Error { + private readonly _originalMessage: string; + + constructor(localizedMessage: string, _originalMessage: string, private readonly originalException: any) { + super(localizedMessage); + + } + + get originalMessage(): string { + return this._originalMessage; + } + + getPrintableString(): string { + return JSON.stringify({ + originalMessage: this.originalMessage, + originalException: this.originalException + }, undefined, 2); + } +} diff --git a/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts b/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts index 252e15de0b..14dc71a7d5 100644 --- a/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts +++ b/extensions/azurecore/src/account-provider/auths/azureDeviceCode.ts @@ -8,21 +8,24 @@ import * as vscode from 'vscode'; import * as nls from 'vscode-nls'; import { AzureAuth, - TokenClaims, - AccessToken, - RefreshToken, + OAuthTokenResponse, + DeviceCodeStartPostData, + DeviceCodeCheckPostData, } from './azureAuth'; import { AzureAccountProviderMetadata, - AzureAccount, AzureAuthType, + Tenant, + Resource, + Deferred, // Tenant, // Subscription } from '../interfaces'; import { SimpleTokenCache } from '../simpleTokenCache'; +import { Logger } from '../../utils/Logger'; const localize = nls.loadMessageBundle(); interface DeviceCodeLogin { // https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-device-code @@ -43,7 +46,6 @@ interface DeviceCodeLoginResult { } export class AzureDeviceCode extends AzureAuth { - private static readonly USER_FRIENDLY_NAME: string = localize('azure.azureDeviceCodeAuth', "Azure Device Code"); private readonly pageTitle: string; constructor( @@ -56,60 +58,42 @@ export class AzureDeviceCode extends AzureAuth { this.pageTitle = localize('addAccount', "Add {0} account", this.metadata.displayName); } + protected async login(tenant: Tenant, resource: Resource): Promise<{ response: OAuthTokenResponse, authComplete: Deferred }> { + let authCompleteDeferred: Deferred; + let authCompletePromise = new Promise((resolve, reject) => authCompleteDeferred = { resolve, reject }); - public async promptForConsent(resourceId: string, tenant: string = this.commonTenant.id): Promise { - vscode.window.showErrorMessage(localize('azure.deviceCodeDoesNotSupportConsent', "Device code authentication does not support prompting for consent. Switch the authentication method in settings to code grant.")); - return undefined; + const uri = `${this.loginEndpointUrl}/${this.commonTenant.id}/oauth2/devicecode`; + const postData: DeviceCodeStartPostData = { + client_id: this.clientId, + resource: resource.endpoint + }; + + const postResult = await this.makePostRequest(uri, postData); + + const initialDeviceLogin: DeviceCodeLogin = postResult.data; + + await azdata.accounts.beginAutoOAuthDeviceCode(this.metadata.id, this.pageTitle, initialDeviceLogin.message, initialDeviceLogin.user_code, initialDeviceLogin.verification_url); + + const finalDeviceLogin = await this.setupPolling(initialDeviceLogin); + + const accessTokenString = finalDeviceLogin.access_token; + const refreshTokenString = finalDeviceLogin.refresh_token; + + const currentTime = new Date().getTime() / 1000; + const expiresOn = `${currentTime + finalDeviceLogin.expires_in}`; + + const result = await this.getTokenHelper(tenant, resource, accessTokenString, refreshTokenString, expiresOn); + this.closeOnceComplete(authCompletePromise).catch(Logger.error); + + return { + response: result, + authComplete: authCompleteDeferred + }; } - public async login(): Promise { - try { - const uri = `${this.loginEndpointUrl}/${this.commonTenant}/oauth2/devicecode`; - const postResult = await this.makePostRequest(uri, { - client_id: this.clientId, - resource: this.metadata.settings.signInResourceId - }); - - const initialDeviceLogin: DeviceCodeLogin = postResult.data; - - await azdata.accounts.beginAutoOAuthDeviceCode(this.metadata.id, this.pageTitle, initialDeviceLogin.message, initialDeviceLogin.user_code, initialDeviceLogin.verification_url); - - const finalDeviceLogin = await this.setupPolling(initialDeviceLogin); - - let tokenClaims: TokenClaims; - let accessToken: AccessToken; - let refreshToken: RefreshToken; - // let tenants: Tenant[]; - // let subscriptions: Subscription[]; - - tokenClaims = this.getTokenClaims(finalDeviceLogin.access_token); - - accessToken = { - token: finalDeviceLogin.access_token, - key: tokenClaims.email || tokenClaims.unique_name || tokenClaims.name, - }; - - refreshToken = { - token: finalDeviceLogin.refresh_token, - key: accessToken.key, - }; - - await this.setCachedToken({ accountId: accessToken.key, providerId: this.metadata.id }, accessToken, refreshToken); - - const tenants = await this.getTenants(accessToken); - const account = this.createAccount(tokenClaims, accessToken.key, tenants); - const subscriptions = await this.getSubscriptions(account); - account.properties.subscriptions = subscriptions; - return account; - } catch (ex) { - console.log(ex); - if (ex.msg) { - vscode.window.showErrorMessage(ex.msg); - } - return { canceled: false }; - } finally { - azdata.accounts.endAutoOAuthDeviceCode(); - } + private async closeOnceComplete(promise: Promise): Promise { + await promise; + azdata.accounts.endAutoOAuthDeviceCode(); } @@ -141,14 +125,14 @@ export class AzureDeviceCode extends AzureAuth { const msg = localize('azure.deviceCodeCheckFail', "Error encountered when trying to check for login results"); try { const uri = `${this.loginEndpointUrl}/${this.commonTenant}/oauth2/token`; - const postData = { + const postData: DeviceCodeCheckPostData = { grant_type: 'urn:ietf:params:oauth:grant-type:device_code', client_id: this.clientId, tenant: this.commonTenant.id, code: info.device_code }; - const postResult = await this.makePostRequest(uri, postData, true); + const postResult = await this.makePostRequest(uri, postData); const result: DeviceCodeLoginResult = postResult.data; @@ -164,5 +148,4 @@ export class AzureDeviceCode extends AzureAuth { public async autoOAuthCancelled(): Promise { return azdata.accounts.endAutoOAuthDeviceCode(); } - } diff --git a/extensions/azurecore/src/account-provider/azureAccountProvider.ts b/extensions/azurecore/src/account-provider/azureAccountProvider.ts index 03d382c3f7..1f2471a394 100644 --- a/extensions/azurecore/src/account-provider/azureAccountProvider.ts +++ b/extensions/azurecore/src/account-provider/azureAccountProvider.ts @@ -10,14 +10,15 @@ import * as nls from 'vscode-nls'; import { AzureAccountProviderMetadata, AzureAuthType, - Deferred + Deferred, + AzureAccount } from './interfaces'; import { SimpleTokenCache } from './simpleTokenCache'; -import { AzureAuth, TokenResponse } from './auths/azureAuth'; +import { Logger } from '../utils/Logger'; +import { MultiTenantTokenResponse, Token, AzureAuth } from './auths/azureAuth'; import { AzureAuthCodeGrant } from './auths/azureAuthCodeGrant'; import { AzureDeviceCode } from './auths/azureDeviceCode'; -import { Logger } from '../utils/Logger'; const localize = nls.loadMessageBundle(); @@ -101,14 +102,29 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp } - getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable { + getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable { return this._getSecurityToken(account, resource); } - private async _getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Promise { + getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable { + return this._getAccountSecurityToken(account, tenant, resource); + } + + private async _getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Promise { await this.initCompletePromise; const azureAuth = this.getAuthMethod(undefined); - return azureAuth?.getSecurityToken(account, resource); + return azureAuth?.getAccountSecurityToken(account, tenant, resource); + } + + private async _getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Promise { + vscode.window.showInformationMessage(localize('azure.deprecatedGetSecurityToken', "A call was made to azdata.accounts.getSecurityToken, this method is deprecated and will be removed in future releases. Please use getAccountSecurityToken instead.")); + const azureAccount = account as AzureAccount; + const response: MultiTenantTokenResponse = {}; + for (const tenant of azureAccount.properties.tenants) { + response[tenant.id] = await this._getAccountSecurityToken(account, tenant.id, resource); + } + + return response; } prompt(): Thenable { @@ -134,7 +150,7 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp } if (this.authMappings.size === 1) { - return this.getAuthMethod(undefined).login(); + return this.getAuthMethod(undefined).startLogin(); } const options: Option[] = []; @@ -150,7 +166,7 @@ export class AzureAccountProvider implements azdata.AccountProvider, vscode.Disp return { canceled: true }; } - return pick.azureAuth.login(); + return pick.azureAuth.startLogin(); } refresh(account: azdata.Account): Thenable { diff --git a/extensions/azurecore/src/account-provider/interfaces.ts b/extensions/azurecore/src/account-provider/interfaces.ts index 7e7915bc2b..1977afbd4a 100644 --- a/extensions/azurecore/src/account-provider/interfaces.ts +++ b/extensions/azurecore/src/account-provider/interfaces.ts @@ -64,11 +64,6 @@ interface Settings { */ clientId?: string; - /** - * Identifier of the resource to request when signing in - */ - signInResourceId?: string; - /** * Information that describes the Microsoft resource management resource */ @@ -177,10 +172,6 @@ interface AzureAccountProperties { */ tenants: Tenant[]; - /** - * A list of subscriptions the user belongs to - */ - subscriptions?: Subscription[]; } export interface Subscription { diff --git a/extensions/azurecore/src/account-provider/providerSettings.ts b/extensions/azurecore/src/account-provider/providerSettings.ts index d8968712f3..f14654028d 100644 --- a/extensions/azurecore/src/account-provider/providerSettings.ts +++ b/extensions/azurecore/src/account-provider/providerSettings.ts @@ -17,7 +17,6 @@ const publicAzureSettings: ProviderSettings = { settings: { host: 'https://login.microsoftonline.com/', clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255', - signInResourceId: 'https://management.core.windows.net/', microsoftResource: { id: 'marm', endpoint: 'https://management.core.windows.net/', @@ -72,7 +71,6 @@ const usGovAzureSettings: ProviderSettings = { settings: { host: 'https://login.microsoftonline.us/', clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255', - signInResourceId: 'https://management.core.usgovcloudapi.net/', microsoftResource: { id: 'marm', endpoint: 'https://management.core.usgovcloudapi.net/', @@ -121,7 +119,6 @@ const usNatAzureSettings: ProviderSettings = { settings: { host: 'https://login.microsoftonline.eaglex.ic.gov/', clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255', - signInResourceId: 'https://management.core.eaglex.ic.gov/', microsoftResource: { id: 'marm', endpoint: 'https://management.azure.eaglex.ic.gov/', @@ -171,7 +168,6 @@ const germanyAzureSettings: ProviderSettings = { settings: { host: 'https://login.microsoftazure.de/', clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255', - signInResourceId: 'https://management.core.cloudapi.de/', graphResource: { id: 'https://graph.cloudapi.de/', endpoint: 'https://graph.cloudapi.de' @@ -197,7 +193,6 @@ const chinaAzureSettings: ProviderSettings = { settings: { host: 'https://login.chinacloudapi.cn/', clientId: 'a69788c6-1d43-44ed-9ca3-b83e194da255', - signInResourceId: 'https://management.core.chinacloudapi.cn/', graphResource: { id: 'https://graph.chinacloudapi.cn/', endpoint: 'https://graph.chinacloudapi.cn' diff --git a/extensions/azurecore/src/azureResource/commands.ts b/extensions/azurecore/src/azureResource/commands.ts index 1cdedafcf9..4c87827f6d 100644 --- a/extensions/azurecore/src/azureResource/commands.ts +++ b/extensions/azurecore/src/azureResource/commands.ts @@ -35,7 +35,6 @@ export function registerAzureResourceCommands(appContext: AppContext, tree: Azur const accountNode = node as AzureResourceAccountTreeNode; const azureAccount = accountNode.account as AzureAccount; - const tokens = await azdata.accounts.getSecurityToken(azureAccount, azdata.AzureResource.MicrosoftResourceManagement); const terminalService = appContext.getService(AzureResourceServiceNames.terminalService); @@ -64,7 +63,7 @@ export function registerAzureResourceCommands(appContext: AppContext, tree: Azur tenant = azureAccount.properties.tenants[listOfTenants.indexOf(pickedTenant)]; } - await terminalService.getOrCreateCloudConsole(azureAccount, tenant, tokens); + await terminalService.getOrCreateCloudConsole(azureAccount, tenant); } catch (ex) { console.error(ex); vscode.window.showErrorMessage(ex); @@ -91,13 +90,14 @@ export function registerAzureResourceCommands(appContext: AppContext, tree: Azur const subscriptions = (await accountNode.getCachedSubscriptions()) || []; if (subscriptions.length === 0) { try { - const tokens = await azdata.accounts.getSecurityToken(account, azdata.AzureResource.ResourceManagement); for (const tenant of account.properties.tenants) { - const token = tokens[tenant.id].token; - const tokenType = tokens[tenant.id].tokenType; + const response = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); - subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType))); + const token = response.token; + const tokenType = response.tokenType; + + subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType), tenant.id)); } } catch (error) { account.isStale = true; diff --git a/extensions/azurecore/src/azureResource/interfaces.ts b/extensions/azurecore/src/azureResource/interfaces.ts index ce2139a375..791f16d2a7 100644 --- a/extensions/azurecore/src/azureResource/interfaces.ts +++ b/extensions/azurecore/src/azureResource/interfaces.ts @@ -8,10 +8,10 @@ import * as msRest from '@azure/ms-rest-js'; import { Account } from 'azdata'; import { azureResource } from './azure-resource'; -import { AzureAccount, AzureAccountSecurityToken, Tenant } from '../account-provider/interfaces'; +import { AzureAccount, Tenant } from '../account-provider/interfaces'; export interface IAzureResourceSubscriptionService { - getSubscriptions(account: Account, credential: msRest.ServiceClientCredentials): Promise; + getSubscriptions(account: Account, credential: msRest.ServiceClientCredentials, tenantId: string): Promise; } export interface IAzureResourceSubscriptionFilterService { @@ -20,7 +20,7 @@ export interface IAzureResourceSubscriptionFilterService { } export interface IAzureTerminalService { - getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant, tokens: { [key: string]: AzureAccountSecurityToken }): Promise; + getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant): Promise; } export interface IAzureResourceCacheService { @@ -31,9 +31,6 @@ export interface IAzureResourceCacheService { update(key: string, value: T): void; } -export interface IAzureResourceTenantService { - getTenantId(subscription: azureResource.AzureResourceSubscription, account: Account, credential: msRest.ServiceClientCredentials): Promise; -} export interface IAzureResourceNodeWithProviderId { resourceProviderId: string; diff --git a/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts b/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts index 5e62cb43e5..27815fdd1d 100644 --- a/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts +++ b/extensions/azurecore/src/azureResource/providers/resourceTreeDataProviderBase.ts @@ -41,8 +41,8 @@ export abstract class ResourceTreeDataProviderBase { - const tokens = await azdata.accounts.getSecurityToken(element.account, azdata.AzureResource.ResourceManagement); - const credential = new msRest.TokenCredentials(tokens[element.tenantId].token, tokens[element.tenantId].tokenType); + const response = await azdata.accounts.getAccountSecurityToken(element.account, element.tenantId, azdata.AzureResource.ResourceManagement); + const credential = new msRest.TokenCredentials(response.token, response.tokenType); const resources: T[] = await this._resourceService.getResources(element.subscription, credential, element.account) || []; return resources; diff --git a/extensions/azurecore/src/azureResource/services/subscriptionService.ts b/extensions/azurecore/src/azureResource/services/subscriptionService.ts index 765f65628f..3d2ec0fb08 100644 --- a/extensions/azurecore/src/azureResource/services/subscriptionService.ts +++ b/extensions/azurecore/src/azureResource/services/subscriptionService.ts @@ -10,14 +10,15 @@ import { azureResource } from '../azure-resource'; import { IAzureResourceSubscriptionService } from '../interfaces'; export class AzureResourceSubscriptionService implements IAzureResourceSubscriptionService { - public async getSubscriptions(account: Account, credential: any): Promise { + public async getSubscriptions(account: Account, credential: any, tenantId: string): Promise { const subscriptions: azureResource.AzureResourceSubscription[] = []; const subClient = new SubscriptionClient(credential, { baseUri: account.properties.providerSettings.settings.armResource.endpoint }); const subs = await subClient.subscriptions.list(); subs.forEach((sub) => subscriptions.push({ id: sub.subscriptionId, - name: sub.displayName + name: sub.displayName, + tenant: tenantId })); return subscriptions; diff --git a/extensions/azurecore/src/azureResource/services/tenantService.ts b/extensions/azurecore/src/azureResource/services/tenantService.ts deleted file mode 100644 index ff1bbc8f7c..0000000000 --- a/extensions/azurecore/src/azureResource/services/tenantService.ts +++ /dev/null @@ -1,18 +0,0 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the Source EULA. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ -import { SubscriptionClient } from '@azure/arm-subscriptions'; - -import { azureResource } from '../azure-resource'; -import { IAzureResourceTenantService } from '../interfaces'; -import { Account } from 'azdata'; - -export class AzureResourceTenantService implements IAzureResourceTenantService { - public async getTenantId(subscription: azureResource.AzureResourceSubscription, account: Account, credentials: any): Promise { - const subClient = new SubscriptionClient(credentials, { baseUri: account.properties.providerSettings.settings.armResource.endpoint }); - - const result = await subClient.subscriptions.get(subscription.id); - return result.subscriptionId; - } -} diff --git a/extensions/azurecore/src/azureResource/services/terminalService.ts b/extensions/azurecore/src/azureResource/services/terminalService.ts index 9553b7acb2..9f6c2c0bcc 100644 --- a/extensions/azurecore/src/azureResource/services/terminalService.ts +++ b/extensions/azurecore/src/azureResource/services/terminalService.ts @@ -2,14 +2,14 @@ * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ - +import * as azdata from 'azdata'; import * as vscode from 'vscode'; import * as nls from 'vscode-nls'; import axios, { AxiosRequestConfig, AxiosResponse } from 'axios'; import * as WS from 'ws'; import { IAzureTerminalService } from '../interfaces'; -import { AzureAccount, AzureAccountSecurityToken, Tenant } from '../../account-provider/interfaces'; +import { AzureAccount, Tenant } from '../../account-provider/interfaces'; const localize = nls.loadMessageBundle(); @@ -48,13 +48,13 @@ export class AzureTerminalService implements IAzureTerminalService { } - public async getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant, tokens: { [key: string]: AzureAccountSecurityToken }): Promise { - const token = tokens[tenant.id].token; + public async getOrCreateCloudConsole(account: AzureAccount, tenant: Tenant): Promise { + const token = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.MicrosoftResourceManagement); const settings: AxiosRequestConfig = { headers: { 'Accept': 'application/json', 'Content-Type': 'application/json', - 'Authorization': `Bearer ${token}` + 'Authorization': `Bearer ${token.token}` }, validateStatus: () => true }; @@ -93,7 +93,7 @@ export class AzureTerminalService implements IAzureTerminalService { } const consoleUri = provisionResult.data.properties.uri; - return this.createTerminal(consoleUri, token, account.displayInfo.displayName, preferredShell); + return this.createTerminal(consoleUri, token.token, account.displayInfo.displayName, preferredShell); } diff --git a/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts b/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts index ca81b62da8..012938e865 100644 --- a/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts +++ b/extensions/azurecore/src/azureResource/tree/accountTreeNode.ts @@ -20,11 +20,12 @@ import { AzureResourceSubscriptionTreeNode } from './subscriptionTreeNode'; import { AzureResourceMessageTreeNode } from '../messageTreeNode'; import { AzureResourceErrorMessageUtil } from '../utils'; import { IAzureResourceTreeChangeHandler } from './treeChangeHandler'; -import { IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService, IAzureResourceTenantService } from '../../azureResource/interfaces'; +import { IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService } from '../../azureResource/interfaces'; +import { AzureAccount } from '../../account-provider/interfaces'; export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNodeBase { public constructor( - public readonly account: azdata.Account, + public readonly account: AzureAccount, appContext: AppContext, treeChangeHandler: IAzureResourceTreeChangeHandler ) { @@ -32,7 +33,6 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode this._subscriptionService = this.appContext.getService(AzureResourceServiceNames.subscriptionService); this._subscriptionFilterService = this.appContext.getService(AzureResourceServiceNames.subscriptionFilterService); - this._tenantService = this.appContext.getService(AzureResourceServiceNames.tenantService); this._id = `account_${this.account.key.accountId}`; this.setCacheKey(`${this._id}.subscriptions`); @@ -42,15 +42,13 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode public async getChildren(): Promise { try { let subscriptions: azureResource.AzureResourceSubscription[] = []; - const tokens = await azdata.accounts.getSecurityToken(this.account, azdata.AzureResource.ResourceManagement); if (this._isClearingCache) { try { for (const tenant of this.account.properties.tenants) { - const token = tokens[tenant.id].token; - const tokenType = tokens[tenant.id].tokenType; + const token = await azdata.accounts.getAccountSecurityToken(this.account, tenant.id, azdata.AzureResource.ResourceManagement); - subscriptions.push(...(await this._subscriptionService.getSubscriptions(this.account, new TokenCredentials(token, tokenType)) || [])); + subscriptions.push(...(await this._subscriptionService.getSubscriptions(this.account, new TokenCredentials(token.token, token.tokenType), tenant.id) || [])); } } catch (error) { throw new AzureResourceCredentialError(localize('azure.resource.tree.accountTreeNode.credentialError', "Failed to get credential for account {0}. Please refresh the account.", this.account.key.accountId), error); @@ -80,8 +78,8 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode return [AzureResourceMessageTreeNode.create(AzureResourceAccountTreeNode.noSubscriptionsLabel, this)]; } else { // Filter out everything that we can't authenticate to. - subscriptions = subscriptions.filter(s => { - const token = tokens[s.id]; + subscriptions = subscriptions.filter(async s => { + const token = await azdata.accounts.getAccountSecurityToken(this.account, s.tenant, azdata.AzureResource.ResourceManagement); if (!token) { console.info(`Account does not have permissions to view subscription ${JSON.stringify(s)}.`); return false; @@ -90,10 +88,7 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode }); let subTreeNodes = await Promise.all(subscriptions.map(async (subscription) => { - const token = tokens[subscription.id]; - const tenantId = await this._tenantService.getTenantId(subscription, this.account, new TokenCredentials(token.token, token.tokenType)); - - return new AzureResourceSubscriptionTreeNode(this.account, subscription, tenantId, this.appContext, this.treeChangeHandler, this); + return new AzureResourceSubscriptionTreeNode(this.account, subscription, subscription.tenant, this.appContext, this.treeChangeHandler, this); })); return subTreeNodes.sort((a, b) => a.subscription.name.localeCompare(b.subscription.name)); } @@ -166,7 +161,6 @@ export class AzureResourceAccountTreeNode extends AzureResourceContainerTreeNode private _subscriptionService: IAzureResourceSubscriptionService = undefined; private _subscriptionFilterService: IAzureResourceSubscriptionFilterService = undefined; - private _tenantService: IAzureResourceTenantService = undefined; private _id: string = undefined; private _label: string = undefined; diff --git a/extensions/azurecore/src/azureResource/utils.ts b/extensions/azurecore/src/azureResource/utils.ts index 605f69bbe6..dfa7ee76f8 100644 --- a/extensions/azurecore/src/azureResource/utils.ts +++ b/extensions/azurecore/src/azureResource/utils.ts @@ -151,13 +151,13 @@ export async function getSubscriptions(appContext: AppContext, account?: azdata. } const subscriptionService = appContext.getService(AzureResourceServiceNames.subscriptionService); - const tokens = await azdata.accounts.getSecurityToken(account, azdata.AzureResource.ResourceManagement); - await Promise.all(account.properties.tenants.map(async (tenant: { id: string | number; }) => { + await Promise.all(account.properties.tenants.map(async (tenant: { id: string; }) => { try { - const token = tokens[tenant.id].token; - const tokenType = tokens[tenant.id].tokenType; + const response = await azdata.accounts.getAccountSecurityToken(account, tenant.id, azdata.AzureResource.ResourceManagement); + const token = response.token; + const tokenType = response.tokenType; - result.subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType))); + result.subscriptions.push(...await subscriptionService.getSubscriptions(account, new TokenCredentials(token, tokenType), tenant.id)); } catch (err) { const error = new Error(localize('azure.accounts.getSubscriptions.queryError', "Error fetching subscriptions for account {0} tenant {1} : {2}", account.displayInfo.displayName, diff --git a/extensions/azurecore/src/extension.ts b/extensions/azurecore/src/extension.ts index 89d7add7f9..bb6690953f 100644 --- a/extensions/azurecore/src/extension.ts +++ b/extensions/azurecore/src/extension.ts @@ -17,12 +17,11 @@ import { AzureResourceDatabaseServerService } from './azureResource/providers/da import { AzureResourceDatabaseProvider } from './azureResource/providers/database/databaseProvider'; import { AzureResourceDatabaseService } from './azureResource/providers/database/databaseService'; import { AzureResourceService } from './azureResource/resourceService'; -import { IAzureResourceCacheService, IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService, IAzureResourceTenantService, IAzureTerminalService } from './azureResource/interfaces'; +import { IAzureResourceCacheService, IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService, IAzureTerminalService } from './azureResource/interfaces'; import { AzureResourceServiceNames } from './azureResource/constants'; import { AzureResourceSubscriptionService } from './azureResource/services/subscriptionService'; import { AzureResourceSubscriptionFilterService } from './azureResource/services/subscriptionFilterService'; import { AzureResourceCacheService } from './azureResource/services/cacheService'; -import { AzureResourceTenantService } from './azureResource/services/tenantService'; import { registerAzureResourceCommands } from './azureResource/commands'; import { AzureResourceTreeProvider } from './azureResource/tree/treeProvider'; import { SqlInstanceResourceService } from './azureResource/providers/sqlinstance/sqlInstanceService'; @@ -153,7 +152,6 @@ function registerAzureServices(appContext: AppContext): void { appContext.registerService(AzureResourceServiceNames.cacheService, new AzureResourceCacheService(extensionContext)); appContext.registerService(AzureResourceServiceNames.subscriptionService, new AzureResourceSubscriptionService()); appContext.registerService(AzureResourceServiceNames.subscriptionFilterService, new AzureResourceSubscriptionFilterService(new AzureResourceCacheService(extensionContext))); - appContext.registerService(AzureResourceServiceNames.tenantService, new AzureResourceTenantService()); appContext.registerService(AzureResourceServiceNames.terminalService, new AzureTerminalService(extensionContext)); } diff --git a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts index 9b9ec4673c..82cc7f4612 100644 --- a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts +++ b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts @@ -4,114 +4,240 @@ *--------------------------------------------------------------------------------------------*/ import * as should from 'should'; -import * as os from 'os'; +import * as TypeMoq from 'typemoq'; +// import * as azdata from 'azdata'; +// import * as vscode from 'vscode'; +// import * as sinon from 'sinon'; import 'mocha'; -import * as vscode from 'vscode'; - -import { PromptFailedResult, AccountKey } from 'azdata'; -import { AzureAuth, AccessToken, RefreshToken, TokenClaims, TokenRefreshResponse } from '../../../account-provider/auths/azureAuth'; -import { AzureAccount, AzureAuthType, Deferred, Tenant } from '../../../account-provider/interfaces'; +import { AzureAuthCodeGrant } from '../../../account-provider/auths/azureAuthCodeGrant'; +// import { AzureDeviceCode } from '../../../account-provider/auths/azureDeviceCode'; +import { Token, TokenClaims, AccessToken, RefreshToken, OAuthTokenResponse, TokenPostData } from '../../../account-provider/auths/azureAuth'; +import { Tenant, AzureAccount } from '../../../account-provider/interfaces'; import providerSettings from '../../../account-provider/providerSettings'; -import { SimpleTokenCache } from '../../../account-provider/simpleTokenCache'; -import { CredentialsTestProvider } from '../../stubs/credentialsTestProvider'; +import { AzureResource } from 'azdata'; +import { AxiosResponse } from 'axios'; -class BasicAzureAuth extends AzureAuth { - public async login(): Promise { - throw new Error('Method not implemented.'); - } - public async autoOAuthCancelled(): Promise { - throw new Error('Method not implemented.'); - } +let azureAuthCodeGrant: TypeMoq.IMock; +// let azureDeviceCode: TypeMoq.IMock; - public async promptForConsent(): Promise<{ tokenRefreshResponse: TokenRefreshResponse, authCompleteDeferred: Deferred } | undefined> { - throw new Error('Method not implemented.'); - } -} +const mockToken: Token = { + key: 'someUniqueId', + token: 'test_token', + tokenType: 'Bearer' +}; +let mockAccessToken: AccessToken; +let mockRefreshToken: RefreshToken; -let baseAuth: AzureAuth; +const mockClaims = { + name: 'Name', + email: 'example@example.com', + sub: 'someUniqueId' +} as TokenClaims; -const accountKey: AccountKey = { - accountId: 'SomeAccountKey', - providerId: 'providerId', +const mockTenant: Tenant = { + displayName: 'Tenant Name', + id: 'tenantID', + tenantCategory: 'Home', + userId: 'test_user' }; -const accessToken: AccessToken = { - key: accountKey.accountId, - token: '123' -}; +let mockAccount: AzureAccount; -const refreshToken: RefreshToken = { - key: accountKey.accountId, - token: '321' -}; +const provider = providerSettings[0].metadata; -const resourceId = 'resource'; -const tenantId = 'tenant'; +describe('Azure Authentication', function () { + beforeEach(function () { + azureAuthCodeGrant = TypeMoq.Mock.ofType(AzureAuthCodeGrant, TypeMoq.MockBehavior.Loose, true, provider); + // azureDeviceCode = TypeMoq.Mock.ofType(); -const tenant: Tenant = { - id: tenantId, - displayName: 'common' -}; + azureAuthCodeGrant.callBase = true; + // authDeviceCode.callBase = true; -// These tests don't work on Linux systems because gnome-keyring doesn't like running on headless machines. -describe('AccountProvider.AzureAuth', function (): void { - beforeEach(async function (): Promise { - const tokenCache = new SimpleTokenCache('testTokenService', os.tmpdir(), true, new CredentialsTestProvider()); - await tokenCache.init(); - baseAuth = new BasicAzureAuth(providerSettings[0].metadata, tokenCache, undefined, undefined, AzureAuthType.AuthCodeGrant, 'Auth Code Grant'); + mockAccount = { + isStale: false, + properties: { + tenants: [mockTenant] + } + } as AzureAccount; + + mockAccessToken = { + ...mockToken + }; + mockRefreshToken = { + ...mockToken + }; }); - it('Basic token set and get', async function (): Promise { - await baseAuth.setCachedToken(accountKey, accessToken, refreshToken); - const result = await baseAuth.getCachedToken(accountKey); + it('accountHydration should yield a valid account', async function () { - should(JSON.stringify(result.accessToken)).be.equal(JSON.stringify(accessToken)); - should(JSON.stringify(result.refreshToken)).be.equal(JSON.stringify(refreshToken)); + azureAuthCodeGrant.setup(x => x.getTenants(mockToken)).returns((): Promise => { + return Promise.resolve([ + mockTenant + ]); + }); + + const response = await azureAuthCodeGrant.object.hydrateAccount(mockToken, mockClaims); + should(response.displayInfo.displayName).be.equal(`${mockClaims.name} - ${mockClaims.email}`, 'Account name should match'); + should(response.displayInfo.userId).be.equal(mockClaims.sub, 'Account ID should match'); + should(response.properties.tenants).be.deepEqual([mockTenant], 'Tenants should match'); }); - it('Token set and get with tenant and resource id', async function (): Promise { - await baseAuth.setCachedToken(accountKey, accessToken, refreshToken, resourceId, tenantId); - let result = await baseAuth.getCachedToken(accountKey, resourceId, tenantId); + describe('getAccountSecurityToken', function () { + it('should be undefined on stale account', async function () { + mockAccount.isStale = true; + const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, TypeMoq.It.isAny(), TypeMoq.It.isAny()); + should(securityToken).be.undefined(); + }); + it('dont find correct resources', async function () { + const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, TypeMoq.It.isAny(), -1); + should(securityToken).be.undefined(); + }); + it('incorrect tenant', async function () { + await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, 'invalid_tenant', AzureResource.MicrosoftResourceManagement).should.be.rejected(); + }); - should(JSON.stringify(result.accessToken)).be.equal(JSON.stringify(accessToken)); - should(JSON.stringify(result.refreshToken)).be.equal(JSON.stringify(refreshToken)); + it('saved token exists and can be reused', async function () { + delete (mockAccessToken as any).tokenType; + azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => { + return Promise.resolve({ + accessToken: mockAccessToken, + refreshToken: mockRefreshToken, + expiresOn: `${(new Date().getTime() / 1000) + (10 * 60)}` + }) + }); + const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement); + + should(securityToken.tokenType).be.equal('Bearer', 'tokenType should be bearer on a successful getSecurityToken from cache') + }); + + + it('saved token had invalid expiration', async function () { + delete (mockAccessToken as any).tokenType; + (mockAccessToken as any).invalidData = 'this should not exist on response'; + azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => { + return Promise.resolve({ + accessToken: mockAccessToken, + refreshToken: mockRefreshToken, + expiresOn: undefined + }); + }); + azureAuthCodeGrant.setup(x => x.refreshToken(mockTenant, provider.settings.microsoftResource, mockRefreshToken)).returns((): Promise => { + const mockToken: AccessToken = JSON.parse(JSON.stringify(mockAccessToken)); + delete (mockToken as any).invalidData; + return Promise.resolve({ + accessToken: mockToken + } as OAuthTokenResponse); + }); + const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement); + + should((securityToken as any).invalidData).be.undefined(); // Ensure its a new one + should(securityToken.tokenType).be.equal('Bearer', 'tokenType should be bearer on a successful getSecurityToken from cache') + + azureAuthCodeGrant.verify(x => x.refreshToken(mockTenant, provider.settings.microsoftResource, mockRefreshToken), TypeMoq.Times.once()); + }); + + describe('no saved token', function () { + it('no base token', async function () { + azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => { + return Promise.resolve(undefined); + }); + + azureAuthCodeGrant.setup(x => x.getSavedToken(azureAuthCodeGrant.object.commonTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => { + return Promise.resolve(undefined); + }); + + await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement).should.be.rejected(); + }); + + it('base token exists', async function () { + azureAuthCodeGrant.setup(x => x.getSavedToken(mockTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => { + return Promise.resolve(undefined); + }); + + azureAuthCodeGrant.setup(x => x.getSavedToken(azureAuthCodeGrant.object.commonTenant, provider.settings.microsoftResource, mockAccount.key)).returns((): Promise<{ accessToken: AccessToken, refreshToken: RefreshToken, expiresOn: string }> => { + return Promise.resolve({ + accessToken: mockAccessToken, + refreshToken: mockRefreshToken, + expiresOn: '' + }); + }); + delete (mockAccessToken as any).tokenType; + + azureAuthCodeGrant.setup(x => x.refreshToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return Promise.resolve({ + accessToken: mockAccessToken + } as OAuthTokenResponse); + }); + + const securityToken: Token = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement); + should(securityToken.tokenType).be.equal('Bearer', 'tokenType should be bearer on a successful getSecurityToken from cache') + }); + }); - await baseAuth.clearCredentials(accountKey); - result = await baseAuth.getCachedToken(accountKey, resourceId, tenantId); - should(result).be.undefined(); }); - it('Token set with resource ID and get without tenant and resource id', async function (): Promise { - await baseAuth.setCachedToken(accountKey, accessToken, refreshToken, resourceId, tenantId); - const result = await baseAuth.getCachedToken(accountKey); + describe('getToken', function () { - should(JSON.stringify(result)).be.undefined(); - should(JSON.stringify(result)).be.undefined(); + it('calls handle interaction required', async function () { + azureAuthCodeGrant.setup(x => x.makePostRequest(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return Promise.resolve({ + data: { + error: 'interaction_required' + } + } as AxiosResponse); + }); + + azureAuthCodeGrant.setup(x => x.handleInteractionRequired(mockTenant, provider.settings.microsoftResource)).returns(() => { + return Promise.resolve({ + accessToken: mockAccessToken + } as OAuthTokenResponse); + }); + + + const result = await azureAuthCodeGrant.object.getToken(mockTenant, provider.settings.microsoftResource, {} as TokenPostData); + + azureAuthCodeGrant.verify(x => x.handleInteractionRequired(mockTenant, provider.settings.microsoftResource), TypeMoq.Times.once()); + + should(result.accessToken).be.deepEqual(mockAccessToken); + }); + + it('unknown error should throw error', async function () { + azureAuthCodeGrant.setup(x => x.makePostRequest(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return Promise.resolve({ + data: { + error: 'unknown error' + } + } as AxiosResponse); + }); + + await azureAuthCodeGrant.object.getToken(mockTenant, provider.settings.microsoftResource, {} as TokenPostData).should.be.rejected(); + }); + + it('calls getTokenHelper', async function () { + azureAuthCodeGrant.setup(x => x.makePostRequest(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return Promise.resolve({ + data: { + access_token: mockAccessToken.token, + refresh_token: mockRefreshToken.token, + expires_on: `0` + } + } as AxiosResponse); + }); + + azureAuthCodeGrant.setup(x => x.getTokenHelper(mockTenant, provider.settings.microsoftResource, TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return Promise.resolve({ + accessToken: mockAccessToken + } as OAuthTokenResponse); + }); + + + const result = await azureAuthCodeGrant.object.getToken(mockTenant, provider.settings.microsoftResource, {} as TokenPostData); + + azureAuthCodeGrant.verify(x => x.getTokenHelper(mockTenant, provider.settings.microsoftResource, TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); + + should(result.accessToken).be.deepEqual(mockAccessToken); + }); }); - it('Create an account object', async function (): Promise { - const tokenClaims = { - idp: 'live.com', - name: 'TestAccount', - } as TokenClaims; - - const account = baseAuth.createAccount(tokenClaims, 'someKey', undefined); - - should(account.properties.azureAuthType).be.equal(AzureAuthType.AuthCodeGrant); - should(account.key.accountId).be.equal('someKey'); - should(account.properties.isMsAccount).be.equal(true); - }); - it('Should handle ignored tenants', async function (): Promise { - // Don't sit on the await openConsentDialog if test is failing - this.timeout(3000); - - const configuration = vscode.workspace.getConfiguration('azure.tenant.config'); - const values = [tenantId]; - - await configuration.update('filter', values, vscode.ConfigurationTarget.Global); - const x = await baseAuth.openConsentDialog(tenant, resourceId); - - should(x).be.false(); - }); }); diff --git a/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts b/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts index e40fbddfa4..5cf1b7b561 100644 --- a/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts +++ b/extensions/azurecore/src/test/azureResource/providers/database/databaseTreeDataProvider.test.ts @@ -41,12 +41,15 @@ const mockAccount: AzureAccount = { isStale: false }; +const mockTenantId: string = 'mock_tenant'; + + const mockSubscription: azureResource.AzureResourceSubscription = { id: 'mock_subscription', - name: 'mock subscription' + name: 'mock subscription', + tenant: mockTenantId }; -const mockTenantId: string = 'mock_tenant'; const mockResourceRootNode: azureResource.IAzureResourceNode = { account: mockAccount, @@ -61,8 +64,7 @@ const mockResourceRootNode: azureResource.IAzureResourceNode = { } }; -const mockTokens: { [key: string]: any } = {}; -mockTokens[mockTenantId] = { +const mockToken = { token: 'mock_token', tokenType: 'Bearer' }; @@ -106,7 +108,7 @@ describe('AzureResourceDatabaseTreeDataProvider.getChildren', function (): void mockDatabaseService = TypeMoq.Mock.ofType>(); mockExtensionContext = TypeMoq.Mock.ofType(); - sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens)); + sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken)); mockDatabaseService.setup((o) => o.getResources(mockSubscription, TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockDatabases)); mockExtensionContext.setup((o) => o.asAbsolutePath(TypeMoq.It.isAnyString())).returns(() => TypeMoq.It.isAnyString()); }); diff --git a/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts b/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts index 330b47b2cf..092a970472 100644 --- a/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts +++ b/extensions/azurecore/src/test/azureResource/providers/databaseServer/databaseServerTreeDataProvider.test.ts @@ -41,12 +41,14 @@ const mockAccount: AzureAccount = { isStale: false }; +const mockTenantId: string = 'mock_tenant'; + const mockSubscription: azureResource.AzureResourceSubscription = { id: 'mock_subscription', - name: 'mock subscription' + name: 'mock subscription', + tenant: mockTenantId }; -const mockTenantId: string = 'mock_tenant'; const mockResourceRootNode: azureResource.IAzureResourceNode = { account: mockAccount, @@ -61,8 +63,7 @@ const mockResourceRootNode: azureResource.IAzureResourceNode = { } }; -const mockTokens: { [key: string]: any } = {}; -mockTokens[mockTenantId] = { +const mockToken = { token: 'mock_token', tokenType: 'Bearer' }; @@ -106,7 +107,7 @@ describe('AzureResourceDatabaseServerTreeDataProvider.getChildren', function (): mockDatabaseServerService = TypeMoq.Mock.ofType>(); mockExtensionContext = TypeMoq.Mock.ofType(); - sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens)); + sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken)); mockDatabaseServerService.setup((o) => o.getResources(mockSubscription, TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockDatabaseServers)); mockExtensionContext.setup((o) => o.asAbsolutePath(TypeMoq.It.isAnyString())).returns(() => TypeMoq.It.isAnyString()); }); diff --git a/extensions/azurecore/src/test/azureResource/resourceService.test.ts b/extensions/azurecore/src/test/azureResource/resourceService.test.ts index 061c4cb2d1..723263fab1 100644 --- a/extensions/azurecore/src/test/azureResource/resourceService.test.ts +++ b/extensions/azurecore/src/test/azureResource/resourceService.test.ts @@ -33,13 +33,14 @@ const mockAccount: AzureAccount = { isStale: false }; +const mockTenantId: string = 'mock_tenant'; + const mockSubscription: azureResource.AzureResourceSubscription = { id: 'mock_subscription', - name: 'mock subscription' + name: 'mock subscription', + tenant: mockTenantId }; -const mockTenantId: string = 'mock_tenant'; - let mockResourceTreeDataProvider1: TypeMoq.IMock; let mockResourceProvider1: TypeMoq.IMock; diff --git a/extensions/azurecore/src/test/azureResource/resourceTreeNode.test.ts b/extensions/azurecore/src/test/azureResource/resourceTreeNode.test.ts index f16a10a9fa..803e7050f9 100644 --- a/extensions/azurecore/src/test/azureResource/resourceTreeNode.test.ts +++ b/extensions/azurecore/src/test/azureResource/resourceTreeNode.test.ts @@ -36,12 +36,14 @@ const mockAccount: AzureAccount = { isStale: false }; +const mockTenantId: string = 'mock_tenant'; + const mockSubscription: azureResource.AzureResourceSubscription = { id: 'mock_subscription', - name: 'mock subscription' + name: 'mock subscription', + tenant: mockTenantId }; -const mockTenantId: string = 'mock_tenant'; const mockResourceProviderId: string = 'mock_resource_provider'; diff --git a/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts b/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts index 29cb04503d..52cb4dfe59 100644 --- a/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts +++ b/extensions/azurecore/src/test/azureResource/tree/accountTreeNode.test.ts @@ -17,7 +17,6 @@ import { IAzureResourceCacheService, IAzureResourceSubscriptionService, IAzureResourceSubscriptionFilterService, - IAzureResourceTenantService } from '../../../azureResource/interfaces'; import { IAzureResourceTreeChangeHandler } from '../../../azureResource/tree/treeChangeHandler'; import { AzureResourceAccountTreeNode } from '../../../azureResource/tree/accountTreeNode'; @@ -31,7 +30,6 @@ let mockExtensionContext: TypeMoq.IMock; let mockCacheService: TypeMoq.IMock; let mockSubscriptionService: TypeMoq.IMock; let mockSubscriptionFilterService: TypeMoq.IMock; -let mockTenantService: TypeMoq.IMock; let mockAppContext: AppContext; let getSecurityTokenStub: sinon.SinonStub; let mockTreeChangeHandler: TypeMoq.IMock; @@ -63,28 +61,27 @@ const mockAccount: azdata.Account = { const mockSubscription1: azureResource.AzureResourceSubscription = { id: 'mock_subscription_1', - name: 'mock subscription 1' + name: 'mock subscription 1', + tenant: mockTenantId }; const mockSubscription2: azureResource.AzureResourceSubscription = { id: 'mock_subscription_2', - name: 'mock subscription 2' + name: 'mock subscription 2', + tenant: mockTenantId }; const mockSubscriptions = [mockSubscription1, mockSubscription2]; const mockFilteredSubscriptions = [mockSubscription1]; -const mockTokens: { [key: string]: any } = {}; +const mockToken = { + token: 'mock_token', + tokenType: 'Bearer' +}; -[mockSubscription1.id, mockSubscription2.id, mockTenantId].forEach(s => { - mockTokens[s] = { - token: 'mock_token', - tokenType: 'Bearer' - }; -}); -const mockCredential = new TokenCredentials(mockTokens[mockTenantId].token, mockTokens[mockTenantId].tokenType); +const mockCredential = new TokenCredentials(mockToken.token, mockToken.tokenType); let mockSubscriptionCache: azureResource.AzureResourceSubscription[] = []; @@ -94,7 +91,6 @@ describe('AzureResourceAccountTreeNode.info', function (): void { mockCacheService = TypeMoq.Mock.ofType(); mockSubscriptionService = TypeMoq.Mock.ofType(); mockSubscriptionFilterService = TypeMoq.Mock.ofType(); - mockTenantService = TypeMoq.Mock.ofType(); mockTreeChangeHandler = TypeMoq.Mock.ofType(); @@ -104,13 +100,11 @@ describe('AzureResourceAccountTreeNode.info', function (): void { mockAppContext.registerService(AzureResourceServiceNames.cacheService, mockCacheService.object); mockAppContext.registerService(AzureResourceServiceNames.subscriptionService, mockSubscriptionService.object); mockAppContext.registerService(AzureResourceServiceNames.subscriptionFilterService, mockSubscriptionFilterService.object); - mockAppContext.registerService(AzureResourceServiceNames.tenantService, mockTenantService.object); - getSecurityTokenStub = sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens)); + getSecurityTokenStub = sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken)); mockCacheService.setup((o) => o.generateKey(TypeMoq.It.isAnyString())).returns(() => generateGuid()); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache = mockSubscriptions); - mockTenantService.setup((o) => o.getTenantId(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockTenantId)); }); afterEach(function (): void { @@ -138,7 +132,7 @@ describe('AzureResourceAccountTreeNode.info', function (): void { }); it('Should be correct when there are subscriptions listed.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(undefined)); const accountTreeNodeLabel = `${mockAccount.displayInfo.displayName} (${mockSubscriptions.length} / ${mockSubscriptions.length} subscriptions)`; @@ -158,7 +152,7 @@ describe('AzureResourceAccountTreeNode.info', function (): void { }); it('Should be correct when there are subscriptions filtered.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(mockFilteredSubscriptions)); const accountTreeNodeLabel = `${mockAccount.displayInfo.displayName} (${mockFilteredSubscriptions.length} / ${mockSubscriptions.length} subscriptions)`; @@ -184,7 +178,6 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { mockCacheService = TypeMoq.Mock.ofType(); mockSubscriptionService = TypeMoq.Mock.ofType(); mockSubscriptionFilterService = TypeMoq.Mock.ofType(); - mockTenantService = TypeMoq.Mock.ofType(); mockTreeChangeHandler = TypeMoq.Mock.ofType(); @@ -194,13 +187,11 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { mockAppContext.registerService(AzureResourceServiceNames.cacheService, mockCacheService.object); mockAppContext.registerService(AzureResourceServiceNames.subscriptionService, mockSubscriptionService.object); mockAppContext.registerService(AzureResourceServiceNames.subscriptionFilterService, mockSubscriptionFilterService.object); - mockAppContext.registerService(AzureResourceServiceNames.tenantService, mockTenantService.object); - sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens)); + sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken)); mockCacheService.setup((o) => o.generateKey(TypeMoq.It.isAnyString())).returns(() => generateGuid()); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache = mockSubscriptions); - mockTenantService.setup((o) => o.getTenantId(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockTenantId)); }); afterEach(function (): void { @@ -208,14 +199,14 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should load subscriptions from scratch and update cache when it is clearing cache.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve([])); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); const children = await accountTreeNode.getChildren(); - mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential), TypeMoq.Times.once()); + mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.exactly(0)); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockSubscriptionFilterService.verify((o) => o.getSelectedSubscriptions(mockAccount), TypeMoq.Times.once()); @@ -241,7 +232,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should load subscriptions from cache when it is not clearing cache.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(undefined)); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); @@ -250,7 +241,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { const children = await accountTreeNode.getChildren(); - mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential), TypeMoq.Times.once()); + mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.once()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); @@ -262,7 +253,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should handle when there is no subscriptions.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(undefined)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(undefined)); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); @@ -278,7 +269,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should honor subscription filtering.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions)); mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => Promise.resolve(mockFilteredSubscriptions)); const accountTreeNode = new AzureResourceAccountTreeNode(mockAccount, mockAppContext, mockTreeChangeHandler.object); @@ -296,7 +287,7 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { }); it('Should handle errors.', async function (): Promise { - mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential)).returns(() => Promise.resolve(mockSubscriptions)); + mockSubscriptionService.setup((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId)).returns(() => Promise.resolve(mockSubscriptions)); const mockError = 'Test error'; mockSubscriptionFilterService.setup((o) => o.getSelectedSubscriptions(mockAccount)).returns(() => { throw new Error(mockError); }); @@ -305,8 +296,8 @@ describe('AzureResourceAccountTreeNode.getChildren', function (): void { const children = await accountTreeNode.getChildren(); - should(getSecurityTokenStub.calledOnce).be.true('getSecurityToken should have been called exactly once'); - mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential), TypeMoq.Times.once()); + should(getSecurityTokenStub.calledTwice).be.true('getSecurityToken should have been called exactly twice - once per subscription'); + mockSubscriptionService.verify((o) => o.getSubscriptions(mockAccount, mockCredential, mockTenantId), TypeMoq.Times.once()); mockSubscriptionFilterService.verify((o) => o.getSelectedSubscriptions(mockAccount), TypeMoq.Times.once()); mockCacheService.verify((o) => o.get(TypeMoq.It.isAnyString()), TypeMoq.Times.never()); mockCacheService.verify((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny()), TypeMoq.Times.once()); @@ -325,7 +316,6 @@ describe('AzureResourceAccountTreeNode.clearCache', function (): void { mockCacheService = TypeMoq.Mock.ofType(); mockSubscriptionService = TypeMoq.Mock.ofType(); mockSubscriptionFilterService = TypeMoq.Mock.ofType(); - mockTenantService = TypeMoq.Mock.ofType(); mockTreeChangeHandler = TypeMoq.Mock.ofType(); @@ -335,13 +325,11 @@ describe('AzureResourceAccountTreeNode.clearCache', function (): void { mockAppContext.registerService(AzureResourceServiceNames.cacheService, mockCacheService.object); mockAppContext.registerService(AzureResourceServiceNames.subscriptionService, mockSubscriptionService.object); mockAppContext.registerService(AzureResourceServiceNames.subscriptionFilterService, mockSubscriptionFilterService.object); - mockAppContext.registerService(AzureResourceServiceNames.tenantService, mockTenantService.object); - sinon.stub(azdata.accounts, 'getSecurityToken').returns(Promise.resolve(mockTokens)); + sinon.stub(azdata.accounts, 'getAccountSecurityToken').returns(Promise.resolve(mockToken)); mockCacheService.setup((o) => o.generateKey(TypeMoq.It.isAnyString())).returns(() => generateGuid()); mockCacheService.setup((o) => o.get(TypeMoq.It.isAnyString())).returns(() => mockSubscriptionCache); mockCacheService.setup((o) => o.update(TypeMoq.It.isAnyString(), TypeMoq.It.isAny())).returns(() => mockSubscriptionCache = mockSubscriptions); - mockTenantService.setup((o) => o.getTenantId(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(mockTenantId)); }); afterEach(function (): void { diff --git a/extensions/azurecore/src/test/azureResource/tree/subscriptionTreeNode.test.ts b/extensions/azurecore/src/test/azureResource/tree/subscriptionTreeNode.test.ts index 49ce2c37b4..d9f8e187aa 100644 --- a/extensions/azurecore/src/test/azureResource/tree/subscriptionTreeNode.test.ts +++ b/extensions/azurecore/src/test/azureResource/tree/subscriptionTreeNode.test.ts @@ -43,13 +43,14 @@ const mockAccount: azdata.Account = { isStale: false }; +const mockTenantId: string = 'mock_tenant'; + const mockSubscription: azureResource.AzureResourceSubscription = { id: 'mock_subscription', - name: 'mock subscription' + name: 'mock subscription', + tenant: mockTenantId }; -const mockTenantId: string = 'mock_tenant'; - let mockResourceTreeDataProvider1: TypeMoq.IMock; let mockResourceProvider1: TypeMoq.IMock; diff --git a/extensions/azurecore/src/utils/Logger.ts b/extensions/azurecore/src/utils/Logger.ts index 2853273fda..37893bb677 100644 --- a/extensions/azurecore/src/utils/Logger.ts +++ b/extensions/azurecore/src/utils/Logger.ts @@ -7,17 +7,13 @@ export class Logger { private static _piiLogging: boolean = false; static log(msg: any, ...vals: any[]) { - if (vals && vals.length > 0) { - return console.log(msg, vals); - } - console.log(msg); + const fullMessage = `${msg} - ${vals.map(v => JSON.stringify(v)).join(' - ')}`; + console.log(fullMessage); } static error(msg: any, ...vals: any[]) { - if (vals && vals.length > 0) { - return console.error(msg, vals); - } - console.error(msg); + const fullMessage = `${msg} - ${vals.map(v => JSON.stringify(v)).join(' - ')}`; + console.error(fullMessage); } static pii(msg: any, ...vals: any[]) { diff --git a/extensions/mssql/src/features.ts b/extensions/mssql/src/features.ts index 410affa69e..23b27e8c86 100644 --- a/extensions/mssql/src/features.ts +++ b/extensions/mssql/src/features.ts @@ -70,22 +70,22 @@ export class AccountFeature implements StaticFeature { account = accountList[0]; } - const securityToken: { [key: string]: any } = await azdata.accounts.getSecurityToken(account, azdata.AzureResource.AzureKeyVault); const tenant = account.properties.tenants.find((t: { [key: string]: string }) => request.authority.includes(t.id)); const unauthorizedMessage = localize('mssql.insufficientlyPrivelagedAzureAccount', "The configured Azure account for {0} does not have sufficient permissions for Azure Key Vault to access a column master key for Always Encrypted.", account.key.accountId); if (!tenant) { window.showErrorMessage(unauthorizedMessage); return undefined; } - let tokenBundle = securityToken[tenant.id]; - if (!tokenBundle) { + const securityToken = await azdata.accounts.getAccountSecurityToken(account, tenant, azdata.AzureResource.AzureKeyVault); + + if (!securityToken?.token) { window.showErrorMessage(unauthorizedMessage); return undefined; } let params: contracts.RequestSecurityTokenResponse = { accountKey: JSON.stringify(account.key), - token: securityToken[tenant.id].token + token: securityToken.token }; return params; diff --git a/src/sql/azdata.d.ts b/src/sql/azdata.d.ts index 213e752c24..5caf36b8e7 100644 --- a/src/sql/azdata.d.ts +++ b/src/sql/azdata.d.ts @@ -2131,9 +2131,18 @@ declare module 'azdata' { * @param account Account to generate security token for (defaults to * AzureResource.ResourceManagement if not given) * @return Promise to return the security token + * @deprecated use getAccountSecurityToken */ export function getSecurityToken(account: Account, resource?: AzureResource): Thenable<{ [key: string]: any }>; + /** + * Generates a security token by asking the account's provider + * @param account + * @param tenant + * @param resource + */ + export function getAccountSecurityToken(account: Account, tenant: string, resource: AzureResource): Thenable<{ token: string, tokenType?: string } | undefined>; + /** * An [event](#Event) which fires when the accounts have changed. */ @@ -2279,6 +2288,7 @@ declare module 'azdata' { * @param account The account to generate a security token for * @param resource The resource to get the token for * @return Promise to return a security token object + * @deprecated use getAccountSecurityToken */ getSecurityToken(account: Account, resource: AzureResource): Thenable<{} | undefined>; diff --git a/src/sql/azdata.proposed.d.ts b/src/sql/azdata.proposed.d.ts index 7966e1dc0e..48431f315c 100644 --- a/src/sql/azdata.proposed.d.ts +++ b/src/sql/azdata.proposed.d.ts @@ -494,4 +494,14 @@ declare module 'azdata' { name?: string; } + export interface AccountProvider { + /** + * Generates a security token for the provided account and tenant + * @param account The account to generate a security token for + * @param resource The resource to get the token for + * @return Promise to return a security token object + */ + getAccountSecurityToken(account: Account, tenant: string, resource: AzureResource): Thenable<{ token: string } | undefined>; + } + } diff --git a/src/sql/platform/accounts/common/interfaces.ts b/src/sql/platform/accounts/common/interfaces.ts index 4057116cd6..b930668aa9 100644 --- a/src/sql/platform/accounts/common/interfaces.ts +++ b/src/sql/platform/accounts/common/interfaces.ts @@ -21,7 +21,11 @@ export interface IAccountManagementService { getAccountProviderMetadata(): Thenable; getAccountsForProvider(providerId: string): Thenable; getAccounts(): Thenable; + /** + * @deprecated + */ getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{ [key: string]: { token: string } }>; + getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }>; removeAccount(accountKey: azdata.AccountKey): Thenable; removeAccounts(): Thenable; refreshAccount(account: azdata.Account): Thenable; diff --git a/src/sql/platform/accounts/test/common/testAccountManagementService.ts b/src/sql/platform/accounts/test/common/testAccountManagementService.ts index 93992ac7dd..2800464214 100644 --- a/src/sql/platform/accounts/test/common/testAccountManagementService.ts +++ b/src/sql/platform/accounts/test/common/testAccountManagementService.ts @@ -55,6 +55,10 @@ export class TestAccountManagementService implements IAccountManagementService { return Promise.resolve([]); } + getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> { + return Promise.resolve(undefined); + } + removeAccount(accountKey: azdata.AccountKey): Thenable { throw new Error('Method not implemented'); } @@ -100,6 +104,9 @@ export class AccountProviderStub implements azdata.AccountProvider { return Promise.resolve({}); } + getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> { + return Promise.resolve(undefined); + } initialize(storedAccounts: azdata.Account[]): Thenable { return Promise.resolve(storedAccounts); } diff --git a/src/sql/workbench/api/browser/mainThreadAccountManagement.ts b/src/sql/workbench/api/browser/mainThreadAccountManagement.ts index 812fcd304b..9b56f2cc8d 100644 --- a/src/sql/workbench/api/browser/mainThreadAccountManagement.ts +++ b/src/sql/workbench/api/browser/mainThreadAccountManagement.ts @@ -75,9 +75,13 @@ export class MainThreadAccountManagement extends Disposable implements MainThrea clear(accountKey: azdata.AccountKey): Thenable { return self._proxy.$clear(handle, accountKey); }, + getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{}> { return self._proxy.$getSecurityToken(account, resource); }, + getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> { + return self._proxy.$getAccountSecurityToken(account, tenant, resource); + }, initialize(restoredAccounts: azdata.Account[]): Thenable { return self._proxy.$initialize(handle, restoredAccounts); }, diff --git a/src/sql/workbench/api/common/extHostAccountManagement.ts b/src/sql/workbench/api/common/extHostAccountManagement.ts index 66e33cdf4e..2f6bef5732 100644 --- a/src/sql/workbench/api/common/extHostAccountManagement.ts +++ b/src/sql/workbench/api/common/extHostAccountManagement.ts @@ -89,10 +89,7 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape { return Promise.all(promises).then(() => resultAccounts); } - public $getSecurityToken(account: azdata.Account, resource?: azdata.AzureResource): Thenable<{}> { - if (resource === undefined) { - resource = AzureResource.ResourceManagement; - } + public $getSecurityToken(account: azdata.Account, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{}> { return this.$getAllAccounts().then(() => { for (const handle in this._accounts) { const providerHandle = parseInt(handle); @@ -105,6 +102,20 @@ export class ExtHostAccountManagement extends ExtHostAccountManagementShape { }); } + public $getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource = AzureResource.ResourceManagement): Thenable<{ token: string }> { + return this.$getAllAccounts().then(() => { + for (const handle in this._accounts) { + const providerHandle = parseInt(handle); + if (firstIndex(this._accounts[handle], (acct) => acct.key.accountId === account.key.accountId) !== -1) { + return this._withProvider(providerHandle, (provider: azdata.AccountProvider) => provider.getAccountSecurityToken(account, tenant, resource)); + } + } + + throw new Error(`Account ${account.key.accountId} not found.`); + }); + } + + public get onDidChangeAccounts(): Event { return this._onDidChangeAccounts.event; } diff --git a/src/sql/workbench/api/common/sqlExtHost.api.impl.ts b/src/sql/workbench/api/common/sqlExtHost.api.impl.ts index 22f0e7f2e8..625554e126 100644 --- a/src/sql/workbench/api/common/sqlExtHost.api.impl.ts +++ b/src/sql/workbench/api/common/sqlExtHost.api.impl.ts @@ -160,6 +160,9 @@ export function createAdsApiFactory(accessor: ServicesAccessor): IAdsExtensionAp getSecurityToken(account: azdata.Account, resource?: azdata.AzureResource): Thenable<{}> { return extHostAccountManagement.$getSecurityToken(account, resource); }, + getAccountSecurityToken(account: azdata.Account, tenant: string, resource?: azdata.AzureResource): Thenable<{ token: string }> { + return extHostAccountManagement.$getAccountSecurityToken(account, tenant, resource); + }, onDidChangeAccounts(listener: (e: azdata.DidChangeAccountsParams) => void, thisArgs?: any, disposables?: extHostTypes.Disposable[]) { return extHostAccountManagement.onDidChangeAccounts(listener, thisArgs, disposables); } diff --git a/src/sql/workbench/api/common/sqlExtHost.protocol.ts b/src/sql/workbench/api/common/sqlExtHost.protocol.ts index 1dfd51fc31..f955a55345 100644 --- a/src/sql/workbench/api/common/sqlExtHost.protocol.ts +++ b/src/sql/workbench/api/common/sqlExtHost.protocol.ts @@ -31,6 +31,7 @@ export abstract class ExtHostAccountManagementShape { $autoOAuthCancelled(handle: number): Thenable { throw ni(); } $clear(handle: number, accountKey: azdata.AccountKey): Thenable { throw ni(); } $getSecurityToken(account: azdata.Account, resource?: azdata.AzureResource): Thenable<{}> { throw ni(); } + $getAccountSecurityToken(account: azdata.Account, tenant: string, resource?: azdata.AzureResource): Thenable<{ token: string }> { throw ni(); } $initialize(handle: number, restoredAccounts: azdata.Account[]): Thenable { throw ni(); } $prompt(handle: number): Thenable { throw ni(); } $refresh(handle: number, account: azdata.Account): Thenable { throw ni(); } diff --git a/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts b/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts index 169b697791..1abc3a09d0 100644 --- a/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts +++ b/src/sql/workbench/services/accountManagement/browser/accountManagementService.ts @@ -244,6 +244,19 @@ export class AccountManagementService implements IAccountManagementService { }); } + /** + * Generates a security token by asking the account's provider + * @param account Account to generate security token for + * @param tenant Tenant to generate security token for + * @param resource The resource to get the security token for + * @return Promise to return the security token + */ + public getAccountSecurityToken(account: azdata.Account, tenant: string, resource: azdata.AzureResource): Thenable<{ token: string }> { + return this.doWithProvider(account.key.providerId, provider => { + return provider.provider.getAccountSecurityToken(account, tenant, resource); + }); + } + /** * Removes an account from the account store and clears sensitive data in the provider * @param accountKey Key for the account to remove diff --git a/src/sql/workbench/services/connection/browser/connectionManagementService.ts b/src/sql/workbench/services/connection/browser/connectionManagementService.ts index 9bbffae813..154a0b0550 100644 --- a/src/sql/workbench/services/connection/browser/connectionManagementService.ts +++ b/src/sql/workbench/services/connection/browser/connectionManagementService.ts @@ -813,8 +813,8 @@ export class ConnectionManagementService extends Disposable implements IConnecti const accounts = await this._accountManagementService.getAccounts(); const azureAccounts = accounts.filter(a => a.key.providerId.startsWith('azure')); if (azureAccounts && azureAccounts.length > 0) { - let accountName = (connection.authenticationType === Constants.azureMFA || connection.authenticationType === Constants.azureMFAAndUser) ? connection.azureAccount : connection.userName; - let account = find(azureAccounts, account => account.key.accountId === accountName); + let accountId = (connection.authenticationType === Constants.azureMFA || connection.authenticationType === Constants.azureMFAAndUser) ? connection.azureAccount : connection.userName; + let account = find(azureAccounts, account => account.key.accountId === accountId); if (account) { this._logService.debug(`Getting security token for Azure account ${account.key.accountId}`); if (account.isStale) { @@ -827,26 +827,17 @@ export class ConnectionManagementService extends Disposable implements IConnecti return false; } } - const tokensByTenant = await this._accountManagementService.getSecurityToken(account, azureResource); - this._logService.debug(`Got tokens for tenants [${Object.keys(tokensByTenant).join(',')}]`); - let token: string; const tenantId = connection.azureTenantId; - if (tenantId && tokensByTenant[tenantId]) { - token = tokensByTenant[tenantId].token; - } else { - this._logService.debug(`No security token found for specific tenant ${tenantId} - falling back to first one`); - const tokens = values(tokensByTenant); - if (tokens.length === 0) { - this._logService.info(`No security tokens found for account`); - return false; - } - token = tokens[0].token; + const token = await this._accountManagementService.getAccountSecurityToken(account, tenantId, azureResource); + this._logService.debug(`Got token for tenant ${token}`); + if (!token) { + this._logService.info(`No security tokens found for account`); } - connection.options['azureAccountToken'] = token; + connection.options['azureAccountToken'] = token.token; connection.options['password'] = ''; return true; } else { - this._logService.info(`Could not find Azure account with name ${accountName}`); + this._logService.info(`Could not find Azure account with name ${accountId}`); } } else { this._logService.info(`Could not find any Azure accounts from accounts : [${accounts.map(a => `${a.key.accountId} (${a.key.providerId})`).join(',')}]`); diff --git a/src/sql/workbench/services/connection/browser/connectionWidget.ts b/src/sql/workbench/services/connection/browser/connectionWidget.ts index 1f09b1baf3..ff908611c2 100644 --- a/src/sql/workbench/services/connection/browser/connectionWidget.ts +++ b/src/sql/workbench/services/connection/browser/connectionWidget.ts @@ -6,7 +6,7 @@ import 'vs/css!./media/sqlConnection'; import { Button } from 'sql/base/browser/ui/button/button'; -import { SelectBox } from 'sql/base/browser/ui/selectBox/selectBox'; +import { SelectBox, SelectOptionItemSQL } from 'sql/base/browser/ui/selectBox/selectBox'; import { Checkbox } from 'sql/base/browser/ui/checkbox/checkbox'; import { InputBox } from 'sql/base/browser/ui/inputBox/inputBox'; import * as DialogHelper from 'sql/workbench/browser/modal/dialogHelper'; @@ -520,12 +520,18 @@ export class ConnectionWidget extends lifecycle.Disposable { let oldSelection = this._azureAccountDropdown.value; const accounts = await this._accountManagementService.getAccounts(); this._azureAccountList = accounts.filter(a => a.key.providerId.startsWith('azure')); - let accountDropdownOptions = this._azureAccountList.map(account => account.displayInfo.displayName); + let accountDropdownOptions: SelectOptionItemSQL[] = this._azureAccountList.map(account => { + return { + text: account.displayInfo.displayName, + value: account.key.accountId + } as SelectOptionItemSQL; + }); + if (accountDropdownOptions.length === 0) { // If there are no accounts add a blank option so that add account isn't automatically selected - accountDropdownOptions.unshift(''); + accountDropdownOptions.unshift({ text: '', value: '' }); } - accountDropdownOptions.push(this._addAzureAccountMessage); + accountDropdownOptions.push({ text: this._addAzureAccountMessage, value: this._addAzureAccountMessage }); this._azureAccountDropdown.setOptions(accountDropdownOptions); this._azureAccountDropdown.selectWithOptionName(oldSelection); } diff --git a/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts b/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts index e9f9288cfa..009948aa06 100644 --- a/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts +++ b/src/sql/workbench/services/connection/test/browser/connectionManagementService.test.ts @@ -1299,6 +1299,7 @@ suite('SQL ConnectionManagementService tests', () => { let servername = 'test-database.database.windows.net'; azureConnectionProfile.serverName = servername; let providerId = 'azure_PublicCloud'; + azureConnectionProfile.azureTenantId = 'testTenant'; // Set up the account management service to return a token for the given user accountManagementService.setup(x => x.getAccountsForProvider(TypeMoq.It.isAny())).returns(providerId => Promise.resolve([ @@ -1327,10 +1328,9 @@ suite('SQL ConnectionManagementService tests', () => { ]); }); let testToken = 'testToken'; - accountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ - azure_publicCloud: { - token: testToken - } + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ + token: testToken, + tokenType: 'Bearer' })); connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({ profile: profile, @@ -1384,11 +1384,8 @@ suite('SQL ConnectionManagementService tests', () => { ]); }); - let testToken = 'testToken'; - let returnedTokens = {}; - returnedTokens['azure_publicCloud'] = { token: 'badToken' }; - returnedTokens[azureTenantId] = { token: testToken }; - accountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(returnedTokens)); + let returnedToken = { token: 'testToken', tokenType: 'Bearer' }; + accountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(returnedToken)); connectionStore.setup(x => x.addSavedPassword(TypeMoq.It.is(profile => profile.authenticationType === 'AzureMFA'))).returns(profile => Promise.resolve({ profile: profile, savedCred: false @@ -1399,7 +1396,7 @@ suite('SQL ConnectionManagementService tests', () => { // Then the returned profile has the account token set corresponding to the requested tenant assert.equal(profileWithCredentials.userName, azureConnectionProfile.userName); - assert.equal(profileWithCredentials.options['azureAccountToken'], testToken); + assert.equal(profileWithCredentials.options['azureAccountToken'], returnedToken.token); }); test('getConnections test', () => { diff --git a/src/sql/workbench/services/resourceProvider/browser/firewallRuleDialogController.ts b/src/sql/workbench/services/resourceProvider/browser/firewallRuleDialogController.ts index 9b0a134064..06e41b4b75 100644 --- a/src/sql/workbench/services/resourceProvider/browser/firewallRuleDialogController.ts +++ b/src/sql/workbench/services/resourceProvider/browser/firewallRuleDialogController.ts @@ -60,7 +60,12 @@ export class FirewallRuleDialogController { private async handleOnCreateFirewallRule(): Promise { const resourceProviderId = this._resourceProviderId; try { - const securityTokenMappings = await this._accountManagementService.getSecurityToken(this._firewallRuleDialog.viewModel.selectedAccount!, AzureResource.ResourceManagement); + const tenantId = this._connection.azureTenantId; + const token = await this._accountManagementService.getAccountSecurityToken(this._firewallRuleDialog.viewModel.selectedAccount!, tenantId, AzureResource.ResourceManagement); + const securityTokenMappings = { + [tenantId]: token + }; + const firewallRuleInfo: azdata.FirewallRuleInfo = { startIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.fromSubnetIPRange, endIpAddress: this._firewallRuleDialog.viewModel.isIPAddressSelected ? this._firewallRuleDialog.viewModel.defaultIPAddress : this._firewallRuleDialog.viewModel.toSubnetIPRange, diff --git a/src/sql/workbench/services/resourceProvider/test/browser/firewallRuleDialogController.test.ts b/src/sql/workbench/services/resourceProvider/test/browser/firewallRuleDialogController.test.ts index 5c16c968fc..afdd9fc805 100644 --- a/src/sql/workbench/services/resourceProvider/test/browser/firewallRuleDialogController.test.ts +++ b/src/sql/workbench/services/resourceProvider/test/browser/firewallRuleDialogController.test.ts @@ -92,7 +92,8 @@ suite('Firewall rule dialog controller tests', () => { providerName: mssqlProviderName, options: {}, saveProfile: true, - id: '' + id: '', + azureTenantId: 'someTenant' }; }); @@ -137,7 +138,7 @@ suite('Firewall rule dialog controller tests', () => { // Then: it should get security token from account management service and call create firewall rule in resource provider await deferredPromise; - mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); + mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockFirewallRuleDialog.verify(x => x.close(), TypeMoq.Times.once()); mockFirewallRuleDialog.verify(x => x.onServiceComplete(), TypeMoq.Times.once()); @@ -164,7 +165,7 @@ suite('Firewall rule dialog controller tests', () => { // Then: it should get security token from account management service and an error dialog should have been opened await deferredPromise; - mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); + mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.never()); }); @@ -191,7 +192,7 @@ suite('Firewall rule dialog controller tests', () => { // Then: it should get security token from account management service and an error dialog should have been opened await deferredPromise; - mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); + mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); }); @@ -217,7 +218,7 @@ suite('Firewall rule dialog controller tests', () => { // Then: it should get security token from account management service and an error dialog should have been opened await deferredPromise; - mockAccountManagementService.verify(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); + mockAccountManagementService.verify(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockResourceProvider.verify(x => x.createFirewallRule(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); mockErrorMessageService.verify(x => x.showDialog(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); }); @@ -226,8 +227,8 @@ suite('Firewall rule dialog controller tests', () => { function getMockAccountManagementService(resolveSecurityToken: boolean): TypeMoq.Mock { let accountManagementTestService = new TestAccountManagementService(); let mockAccountManagementService = TypeMoq.Mock.ofInstance(accountManagementTestService); - mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny())) - .returns(() => resolveSecurityToken ? Promise.resolve({}) : Promise.reject(null)); + mockAccountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())) + .returns(() => resolveSecurityToken ? Promise.resolve({ token: 'token' }) : Promise.reject(null)); return mockAccountManagementService; } diff --git a/src/sql/workbench/test/electron-browser/api/extHostAccountManagement.test.ts b/src/sql/workbench/test/electron-browser/api/extHostAccountManagement.test.ts index efa0e1049e..07b8720dc2 100644 --- a/src/sql/workbench/test/electron-browser/api/extHostAccountManagement.test.ts +++ b/src/sql/workbench/test/electron-browser/api/extHostAccountManagement.test.ts @@ -442,6 +442,8 @@ function getMockAccountManagementService(accounts: azdata.Account[]): TypeMoq.Mo .returns(() => Promise.resolve(accounts)); mockAccountManagementService.setup(x => x.getSecurityToken(TypeMoq.It.isValue(accounts[0]), TypeMoq.It.isAny())) .returns(() => Promise.resolve({})); + mockAccountManagementService.setup(x => x.getAccountSecurityToken(TypeMoq.It.isValue(accounts[0]), TypeMoq.It.isAny(), TypeMoq.It.isAny())) + .returns(() => Promise.resolve(undefined)); mockAccountManagementService.setup(x => x.updateAccountListEvent) .returns(() => () => { return undefined; });