Introduce AAD interactive auth mode (#1860)

This commit is contained in:
Cheena Malhotra
2023-03-02 09:39:54 -08:00
committed by GitHub
parent 98e50c98fe
commit 187b6ecc14
47 changed files with 918 additions and 151 deletions

View File

@@ -0,0 +1,33 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Represents an access token data object.
/// </summary>
public class AccessToken
{
/// <summary>
/// OAuth 2.0 JWT encoded access token string
/// </summary>
public string Token { get; set; }
/// <summary>
/// Expiry date of token
/// </summary>
public DateTimeOffset ExpiresOn { get; set; }
/// <summary>
/// Default constructor for Access Token object
/// </summary>
/// <param name="token">Access token as string</param>
/// <param name="expiresOn">Expiry date</param>
public AccessToken(string token, DateTimeOffset expiresOn) {
this.Token = token;
this.ExpiresOn = expiresOn;
}
}
}

View File

@@ -0,0 +1,15 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Supported Active Directory authentication modes
/// </summary>
public enum AuthenticationMethod
{
ActiveDirectoryInteractive
}
}

View File

@@ -0,0 +1,82 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Parameters to be passed to <see cref="Authenticator"/> to request an access token
/// </summary>
public class AuthenticationParams
{
/// <summary>
/// Authentication method to be used by <see cref="Authenticator"/>.
/// </summary>
public AuthenticationMethod AuthenticationMethod { get; set; }
/// <summary>
/// Authority URL, e.g. https://login.microsoftonline.com/
/// </summary>
public string Authority { get; set; }
/// <summary>
/// Audience for which access token should be acquired, e.g. common, organizations, consumers.
/// It can also be a tenant Id when authenticating multi-tenant application accounts.
/// </summary>
public string Audience { get; set; }
/// <summary>
/// Resource URL, e.g. https://database.windows.net/
/// </summary>
public string Resource { get; set; }
/// <summary>
/// Array of scopes for which access token is requested.
/// </summary>
public string[] Scopes { get; set; }
/// <summary>
/// <see cref="Guid"/> Connection Id, that will be passed to Azure AD when requesting access token.
/// It can be used for tracking AAD request status if needed.
/// </summary>
public Guid ConnectionId { get; set; }
/// <summary>
/// User name to be provided as userhint when acquiring access token.
/// </summary>
public string UserName { get; set; }
/// <summary>
/// Default constructor
/// </summary>
/// <param name="authMethod">Authentication Method to be used.</param>
/// <param name="authority">Authority URL</param>
/// <param name="audience">Audience</param>
/// <param name="resource">Resource for which token is requested.</param>
/// <param name="scopes">Scopes for access token</param>
/// <param name="userName">User hint information</param>
/// <param name="connectionId">Connection Id for tracing AAD request</param>
public AuthenticationParams(AuthenticationMethod authMethod, string authority, string audience,
string resource, string[] scopes, string userName, Guid connectionId) {
this.AuthenticationMethod = authMethod;
this.Authority = authority;
this.Audience = audience;
this.Resource = resource;
this.Scopes = scopes;
this.UserName = userName;
this.ConnectionId = connectionId;
}
public string ToLogString(bool piiEnabled)
{
return $"\tAuthenticationMethod: {AuthenticationMethod}" +
$"\n\tAuthority: {Authority}" +
$"\n\tAudience: {Audience}" +
$"\n\tResource: {Resource}" +
$"\n\tScopes: {Scopes.ToString()}" +
(piiEnabled ? $"\n\tUsername: {UserName}" : "") +
$"\n\tConnection Id: {ConnectionId}";
}
}
}

View File

@@ -0,0 +1,158 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System.Collections.Concurrent;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Extensions.Msal;
using Microsoft.SqlTools.Authentication.Utility;
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
namespace Microsoft.SqlTools.Authentication
{
/// <summary>
/// Provides APIs to acquire access token using MSAL.NET v4 with provided <see cref="AuthenticationParams"/>.
/// </summary>
public class Authenticator
{
private string applicationClientId;
private string applicationName;
private string cacheFolderPath;
private string cacheFileName;
private MsalCacheHelper cacheHelper;
private static ConcurrentDictionary<string, IPublicClientApplication> PublicClientAppMap
= new ConcurrentDictionary<string, IPublicClientApplication>();
#region Public APIs
public Authenticator(string appClientId, string appName, string cacheFolderPath, string cacheFileName)
{
this.applicationClientId = appClientId;
this.applicationName = appName;
this.cacheFolderPath = cacheFolderPath;
this.cacheFileName = cacheFileName;
// Storage creation properties are used to enable file system caching with MSAL
var storageCreationProperties = new StorageCreationPropertiesBuilder(this.cacheFileName, this.cacheFolderPath)
.WithUnprotectedFile().Build();
// This hooks up the cross-platform cache into MSAL
this.cacheHelper = MsalCacheHelper.CreateAsync(storageCreationProperties).ConfigureAwait(false).GetAwaiter().GetResult();
}
/// <summary>
/// Acquires access token synchronously.
/// </summary>
/// <param name="params">Authentication parameters to be used for access token request.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>Access Token with expiry date</returns>
public AccessToken? GetToken(AuthenticationParams @params, CancellationToken cancellationToken) =>
GetTokenAsync(@params, cancellationToken).GetAwaiter().GetResult();
/// <summary>
/// Acquires access token asynchronously.
/// </summary>
/// <param name="params">Authentication parameters to be used for access token request.</param>
/// <param name="cancellationToken">Cancellation token.</param>
/// <returns>Access Token with expiry date</returns>
/// <exception cref="Exception"></exception>
public async Task<AccessToken?> GetTokenAsync(AuthenticationParams @params, CancellationToken cancellationToken)
{
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Received @params: {@params.ToLogString(SqlToolsLogger.IsPiiEnabled)}");
IPublicClientApplication publicClientApplication = GetPublicClientAppInstance(@params.Authority, @params.Audience);
AccessToken? accessToken;
if (@params.AuthenticationMethod == AuthenticationMethod.ActiveDirectoryInteractive)
{
// Find account
IEnumerator<IAccount>? accounts = (await publicClientApplication.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
IAccount? account = default;
if (!string.IsNullOrEmpty(@params.UserName) && accounts.MoveNext())
{
// Handle username format to extract email: "John Doe - johndoe@constoso.com"
string username = @params.UserName.Contains(" - ") ? @params.UserName.Split(" - ")[1] : @params.UserName;
if (!Utils.isValidEmail(username))
{
SqlToolsLogger.Pii($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Unexpected username format, email not retreived: {@params.UserName}. " +
$"Accepted formats are: 'johndoe@org.com' or 'John Doe - johndoe@org.com'.");
throw new Exception($"Invalid email address format for user: [{username}] received for Azure Active Directory authentication.");
}
do
{
IAccount? currentVal = accounts.Current;
if (string.Compare(username, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
{
account = currentVal;
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account found in MSAL Cache: {account.HomeAccountId}");
break;
}
}
while (accounts.MoveNext());
if (null != account)
{
try
{
// Fetch token silently
var result = await publicClientApplication.AcquireTokenSilent(@params.Scopes, account)
.ExecuteAsync(cancellationToken: cancellationToken)
.ConfigureAwait(false);
accessToken = new AccessToken(result!.AccessToken, result!.ExpiresOn);
}
catch (Exception e)
{
SqlToolsLogger.Verbose($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Silent authentication failed for resource {@params.Resource} for ConnectionId {@params.ConnectionId}.");
SqlToolsLogger.Error(e);
throw;
}
}
else
{
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Account not found in MSAL cache for user.");
throw new Exception($"User account '{username}' not found in MSAL cache, please add linked account or refresh account credentials.");
}
}
else
{
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | User account not received.");
throw new Exception($"User account not received.");
}
}
else
{
SqlToolsLogger.Error($"{nameof(Authenticator)}.{nameof(GetTokenAsync)} | Authentication Method ${@params.AuthenticationMethod} is not supported.");
throw new Exception($"Authentication Method ${@params.AuthenticationMethod} is not supported.");
}
return accessToken;
}
#endregion
#region Private methods
private IPublicClientApplication GetPublicClientAppInstance(string authority, string audience)
{
string authorityUrl = authority + '/' + audience;
if (!PublicClientAppMap.TryGetValue(authorityUrl, out IPublicClientApplication? clientApplicationInstance))
{
clientApplicationInstance = CreatePublicClientAppInstance(authority, audience);
this.cacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache);
PublicClientAppMap.TryAdd(authorityUrl, clientApplicationInstance);
}
return clientApplicationInstance;
}
private IPublicClientApplication CreatePublicClientAppInstance(string authority, string audience) =>
PublicClientApplicationBuilder.Create(this.applicationClientId)
.WithAuthority(authority, audience)
.WithClientName(this.applicationName)
.WithLogging(Utils.MSALLogCallback)
.WithDefaultRedirectUri()
.Build();
#endregion
}
}

View File

@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient" />
<PackageReference Include="Microsoft.Identity.Client" />
<PackageReference Include="Microsoft.Identity.Client.Extensions.Msal" />
<PackageReference Include="System.Configuration.ConfigurationManager" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.SqlTools.Shared\Microsoft.SqlTools.Shared.csproj" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,110 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.Data.SqlClient;
using Microsoft.SqlTools.Authentication.Utility;
namespace Microsoft.SqlTools.Authentication.Sql
{
/// <summary>
/// Provides an implementation of <see cref="SqlAuthenticationProvider"/> for SQL Tools to be able to perform Federated authentication
/// silently with Microsoft.Data.SqlClient integration only for "ActiveDirectory" authentication modes.
/// When registered, the SqlClient driver calls the <see cref="AcquireTokenAsync(SqlAuthenticationParameters)"/> API
/// with server-sent authority information to request access token when needed.
/// </summary>
public class AuthenticationProvider : SqlAuthenticationProvider
{
private const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255";
private const string AzureTokenFolder = "Azure Accounts";
private const string MsalCacheName = "azureTokenCacheMsal-azure_publicCloud";
private const string s_defaultScopeSuffix = "/.default";
private Authenticator authenticator;
/// <summary>
/// Instantiates AuthenticationProvider to be used for AAD authentication with MSAL.NET and MSAL.js co-ordinated.
/// </summary>
/// <param name="applicationName">Application Name that identifies user folder path location for reading/writing to shared cache.</param>
/// <param name="authCallback">Callback that handles AAD authentication when user interaction is needed.</param>
public AuthenticationProvider(string applicationName)
{
if(string.IsNullOrEmpty(applicationName)) {
applicationName = nameof(SqlTools);
}
var cachePath = Path.Combine(Utils.BuildAppDirectoryPath(), applicationName, AzureTokenFolder);
this.authenticator = new Authenticator(ApplicationClientId, applicationName, cachePath, MsalCacheName);
}
/// <summary>
/// Acquires access token with provided <paramref name="parameters"/>
/// </summary>
/// <param name="parameters">Authentication parameters</param>
/// <returns>Authentication token containing access token and expiry date.</returns>
public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters)
{
// Setup scope
string resource = parameters.Resource;
string scope;
if (parameters.Resource.EndsWith(s_defaultScopeSuffix))
{
scope = parameters.Resource;
resource = parameters.Resource.Substring(0, parameters.Resource.LastIndexOf('/') + 1);
}
else
{
scope = parameters.Resource + s_defaultScopeSuffix;
}
string[] scopes = new string[] { scope };
CancellationTokenSource cts = new CancellationTokenSource();
// Use Connection timeout value to cancel token acquire request after certain period of time.
cts.CancelAfter(parameters.ConnectionTimeout * 1000); // Convert to milliseconds
/* We split audience from Authority URL here. Audience can be one of the following:
* The Azure AD authority audience enumeration
* The tenant ID, which can be:
* - A GUID (the ID of your Azure AD instance), for single-tenant applications
* - A domain name associated with your Azure AD instance (also for single-tenant applications)
* One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration:
* - `organizations` for a multitenant application
* - `consumers` to sign in users only with their personal accounts
* - `common` to sign in users with their work and school accounts or their personal Microsoft accounts
*
* MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID.
* If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.)
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
**/
int seperatorIndex = parameters.Authority.LastIndexOf('/');
string authority = parameters.Authority.Remove(seperatorIndex + 1);
string audience = parameters.Authority.Substring(seperatorIndex + 1);
string? userName = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;
AuthenticationParams @params = new AuthenticationParams(
AuthenticationMethod.ActiveDirectoryInteractive,
authority,
audience,
resource,
scopes,
userName!,
parameters.ConnectionId);
AccessToken? result = await authenticator.GetTokenAsync(@params, cts.Token).ConfigureAwait(false);
return new SqlAuthenticationToken(result!.Token, result!.ExpiresOn);
}
/// <summary>
/// Whether or not provided <paramref name="authenticationMethod"/> is supported.
/// </summary>
/// <param name="authenticationMethod">SQL Authentication method</param>
/// <returns></returns>
public override bool IsSupported(SqlAuthenticationMethod authenticationMethod)
=> authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive;
}
}

View File

@@ -0,0 +1,119 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System.Net.Mail;
using System.Runtime.InteropServices;
using Microsoft.Identity.Client;
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
namespace Microsoft.SqlTools.Authentication.Utility
{
internal sealed class Utils
{
/// <summary>
/// Validates provided <paramref name="userEmail"/> follows email format.
/// </summary>
/// <param name="useremail">Email address</param>
/// <returns>Whether email is in correct format.</returns>
public static bool isValidEmail(string userEmail)
{
try
{
new MailAddress(userEmail);
return true;
}
catch (FormatException)
{
return false;
}
}
/// <summary>
/// Builds directory path based on environment settings.
/// </summary>
/// <returns>Application directory path</returns>
/// <exception cref="Exception">When called on unsupported platform.</exception>
public static string BuildAppDirectoryPath()
{
var homedir = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
// Windows
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
var appData = Environment.GetEnvironmentVariable("APPDATA");
var userProfile = Environment.GetEnvironmentVariable("USERPROFILE");
if (appData != null)
{
return appData;
}
else if (userProfile != null)
{
return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming");
}
else
{
throw new Exception("Not able to find APPDATA or USERPROFILE");
}
}
// Mac
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
return string.Join(homedir, "Library", "Application Support");
}
// Linux
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME");
if (xdgConfigHome != null)
{
return xdgConfigHome;
}
else
{
return string.Join(homedir, ".config");
}
}
else
{
throw new Exception("Platform not supported");
}
}
/// <summary>
/// Log callback handler used for MSAL Client applications.
/// </summary>
/// <param name="logLevel">Log level</param>
/// <param name="message">Log message</param>
/// <param name="pii">Whether message contains PII information.</param>
public static void MSALLogCallback(LogLevel logLevel, string message, bool pii)
{
switch (logLevel)
{
case LogLevel.Error:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Error(message);
break;
case LogLevel.Warning:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Warning(message);
break;
case LogLevel.Info:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Information(message);
break;
case LogLevel.Verbose:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Verbose(message);
break;
case LogLevel.Always:
if (pii) SqlToolsLogger.Pii(message);
else SqlToolsLogger.Critical(message);
break;
}
}
}
}