diff --git a/extensions/azurecore/src/account-provider/auths/azureAuth.ts b/extensions/azurecore/src/account-provider/auths/azureAuth.ts index 3b82745571..1d3a350a4d 100644 --- a/extensions/azurecore/src/account-provider/auths/azureAuth.ts +++ b/extensions/azurecore/src/account-provider/auths/azureAuth.ts @@ -145,6 +145,10 @@ export abstract class AzureAuth implements vscode.Disposable { return account; } + private verifyCorrectToken(responseToken: OAuthTokenResponse, account: AzureAccount): boolean { + return this.getUserKey(responseToken.tokenClaims) === account.key.accountId; + } + public async getAccountSecurityToken(account: AzureAccount, tenantId: string, azureResource: azdata.AzureResource): Promise { if (account.isStale === true) { Logger.log('Account was stale. No tokens being fetched.'); @@ -181,6 +185,11 @@ export abstract class AzureAuth implements vscode.Disposable { if (remainingTime < maxTolerance) { const result = await this.refreshToken(tenant, resource, cachedTokens.refreshToken); + + // Verify that the user logged into this account + if (!this.verifyCorrectToken(result, account)) { + return undefined; + } accessToken = result.accessToken; } // Let's just return here. @@ -203,6 +212,10 @@ export abstract class AzureAuth implements vscode.Disposable { } // Let's try to convert the access token type, worst case we'll have to prompt the user to do an interactive authentication. const result = await this.refreshToken(tenant, resource, baseTokens.refreshToken); + // Verify that the user logged into this account + if (!this.verifyCorrectToken(result, account)) { + return undefined; + } if (result.accessToken) { return { ...result.accessToken, @@ -258,6 +271,26 @@ export abstract class AzureAuth implements vscode.Disposable { return this.getTokenHelper(tenant, resource, accessTokenString, refreshTokenString, expiresOnString); } + public getUserKey(tokenClaims: TokenClaims): string { + // Personal accounts don't have an oid when logging into the `common` tenant, but when logging into their home tenant they end up having an oid. + // This makes the key for the same account be different. + // We need to special case personal accounts. + + let userKey: string; + if (tokenClaims.idp === 'live.com') { // Personal account + userKey = tokenClaims.unique_name ?? tokenClaims.email ?? tokenClaims.sub; + } else { + userKey = tokenClaims.home_oid ?? tokenClaims.oid ?? tokenClaims.unique_name ?? tokenClaims.email ?? tokenClaims.sub; + } + + if (!userKey) { + Logger.pii(tokenClaims); + throw new AzureAuthError(localize('azure.userKeyUndefined', "User key was undefined - could not create a userKey from the tokenClaims"), 'user key undefined', undefined); + } + + return userKey; + } + public async getTokenHelper(tenant: Tenant, resource: Resource, accessTokenString: string, refreshTokenString: string, expiresOnString: string): Promise { if (!accessTokenString) { const msg = localize('azure.accessTokenEmpty', 'No access token returned from Microsoft OAuth'); @@ -265,16 +298,7 @@ export abstract class AzureAuth implements vscode.Disposable { } const tokenClaims: TokenClaims = this.getTokenClaims(accessTokenString); - let userKey: string; - - // Personal accounts don't have an oid when logging into the `common` tenant, but when logging into their home tenant they end up having an oid. - // This makes the key for the same account be different. - // We need to special case personal accounts. - if (tokenClaims.idp === 'live.com') { // Personal account - userKey = tokenClaims.unique_name ?? tokenClaims.email ?? tokenClaims.sub; - } else { - userKey = tokenClaims.home_oid ?? tokenClaims.oid ?? tokenClaims.unique_name ?? tokenClaims.email ?? tokenClaims.sub; - } + const userKey = this.getUserKey(tokenClaims); if (!userKey) { const msg = localize('azure.noUniqueIdentifier', "The user had no unique identifier within AAD"); @@ -415,7 +439,6 @@ export abstract class AzureAuth implements vscode.Disposable { if (shouldOpen) { const result = await this.login(tenant, resource); result?.authComplete?.resolve(); - return result?.response; } return undefined; } diff --git a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts index 0929ca5006..46d6f87241 100644 --- a/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts +++ b/extensions/azurecore/src/test/account-provider/auths/azureAuth.test.ts @@ -32,7 +32,9 @@ let mockRefreshToken: RefreshToken; const mockClaims = { name: 'Name', email: 'example@example.com', - sub: 'someUniqueId' + sub: 'someUniqueId', + idp: 'idp', + oid: 'userUniqueKey' } as TokenClaims; const mockTenant: Tenant = { @@ -55,6 +57,9 @@ describe('Azure Authentication', function () { // authDeviceCode.callBase = true; mockAccount = { + key: { + accountId: mockClaims.oid + }, isStale: false, properties: { tenants: [mockTenant] @@ -159,7 +164,8 @@ describe('Azure Authentication', function () { const mockToken: AccessToken = JSON.parse(JSON.stringify(mockAccessToken)); delete (mockToken as any).invalidData; return Promise.resolve({ - accessToken: mockToken + accessToken: mockToken, + tokenClaims: mockClaims } as OAuthTokenResponse); }); const securityToken = await azureAuthCodeGrant.object.getAccountSecurityToken(mockAccount, mockTenant.id, AzureResource.MicrosoftResourceManagement); @@ -192,14 +198,15 @@ describe('Azure Authentication', function () { return Promise.resolve({ accessToken: mockAccessToken, refreshToken: mockRefreshToken, - expiresOn: '' + expiresOn: '', }); }); delete (mockAccessToken as any).tokenType; azureAuthCodeGrant.setup(x => x.refreshToken(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { return Promise.resolve({ - accessToken: mockAccessToken + accessToken: mockAccessToken, + tokenClaims: mockClaims, } as OAuthTokenResponse); }); @@ -273,4 +280,77 @@ describe('Azure Authentication', function () { }); }); + describe('getUserKey', function () { + let tokenClaims: TokenClaims; + beforeEach(function () { + tokenClaims = { + unique_name: 'unique_name', + email: 'email', + sub: 'sub', + oid: 'oid', + home_oid: 'home_oid' + } as TokenClaims; + }); + describe('personal accounts - live.com', function () { + beforeEach(function () { + tokenClaims.idp = 'live.com'; + }); + it('prefer unique_name', function () { + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.unique_name); + }); + it('fallback to email', function () { + delete tokenClaims.unique_name; + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.email); + }); + it('fallback to sub', function () { + delete tokenClaims.unique_name; + delete tokenClaims.email; + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.sub); + }); + }); + describe('work accounts', function () { + it('prefer home_oid', function () { + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.home_oid); + }); + it('fallback to oid', function () { + delete tokenClaims.home_oid; + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.oid); + }); + it('fallback to unique_name', function () { + delete tokenClaims.home_oid; + delete tokenClaims.oid; + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.unique_name); + }); + it('fallback to email', function () { + delete tokenClaims.home_oid; + delete tokenClaims.oid; + delete tokenClaims.unique_name; + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.email); + }); + it('fallback to sub', function () { + delete tokenClaims.home_oid; + delete tokenClaims.oid; + delete tokenClaims.unique_name; + delete tokenClaims.email; + const value = azureAuthCodeGrant.object.getUserKey(tokenClaims); + + should(value).be.equal(tokenClaims.sub); + }); + }); + }); + });