From 4d9cb17c9316f56685c800f4abe8aaa5cd93b626 Mon Sep 17 00:00:00 2001 From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com> Date: Thu, 16 Mar 2023 11:56:35 -0700 Subject: [PATCH] MSAL encrypted file system cache (#1945) --- .../Authenticator.cs | 33 +-- .../IAuthenticator.cs | 27 ++ .../MSALEncryptedCacheHelper.cs | 261 ++++++++++++++++++ .../Sql/AuthenticationProvider.cs | 24 +- .../Utility/AuthenticatorConfiguration.cs | 41 +++ .../Utility/Encryption/EncryptionUtils.cs | 76 +++++ .../Utility/Utils.cs | 54 ---- .../Connection/ConnectionService.cs | 79 +++++- .../Workspace/WorkspaceService.cs | 2 +- .../Utility/Constants.cs | 6 + .../Utility/Utils.cs | 65 +++++ 11 files changed, 568 insertions(+), 100 deletions(-) create mode 100644 src/Microsoft.SqlTools.Authentication/IAuthenticator.cs create mode 100644 src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs create mode 100644 src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs create mode 100644 src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs create mode 100644 src/Microsoft.SqlTools.Shared/Utility/Utils.cs diff --git a/src/Microsoft.SqlTools.Authentication/Authenticator.cs b/src/Microsoft.SqlTools.Authentication/Authenticator.cs index d149a85f..e3f279a0 100644 --- a/src/Microsoft.SqlTools.Authentication/Authenticator.cs +++ b/src/Microsoft.SqlTools.Authentication/Authenticator.cs @@ -5,7 +5,6 @@ using System.Collections.Concurrent; using Microsoft.Identity.Client; -using Microsoft.Identity.Client.Extensions.Msal; using Microsoft.SqlTools.Authentication.Utility; using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger; @@ -14,30 +13,20 @@ namespace Microsoft.SqlTools.Authentication /// /// Provides APIs to acquire access token using MSAL.NET v4 with provided . /// - public class Authenticator + public class Authenticator: IAuthenticator { - private string applicationClientId; - private string applicationName; - private string cacheFolderPath; - private string cacheFileName; - private MsalCacheHelper cacheHelper; + private AuthenticatorConfiguration configuration; + + private MsalEncryptedCacheHelper msalEncryptedCacheHelper; + private static ConcurrentDictionary PublicClientAppMap = new ConcurrentDictionary(); #region Public APIs - public Authenticator(string appClientId, string appName, string cacheFolderPath, string cacheFileName) + public Authenticator(AuthenticatorConfiguration configuration, MsalEncryptedCacheHelper.IvKeyReadCallback callback) { - this.applicationClientId = appClientId; - this.applicationName = appName; - this.cacheFolderPath = cacheFolderPath; - this.cacheFileName = cacheFileName; - - // Storage creation properties are used to enable file system caching with MSAL - var storageCreationProperties = new StorageCreationPropertiesBuilder(this.cacheFileName, this.cacheFolderPath) - .WithUnprotectedFile().Build(); - - // This hooks up the cross-platform cache into MSAL - this.cacheHelper = MsalCacheHelper.CreateAsync(storageCreationProperties).ConfigureAwait(false).GetAwaiter().GetResult(); + this.configuration = configuration; + this.msalEncryptedCacheHelper = new(configuration, callback); } /// @@ -139,16 +128,16 @@ namespace Microsoft.SqlTools.Authentication if (!PublicClientAppMap.TryGetValue(authorityUrl, out IPublicClientApplication? clientApplicationInstance)) { clientApplicationInstance = CreatePublicClientAppInstance(authority, audience); - this.cacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache); + this.msalEncryptedCacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache); PublicClientAppMap.TryAdd(authorityUrl, clientApplicationInstance); } return clientApplicationInstance; } private IPublicClientApplication CreatePublicClientAppInstance(string authority, string audience) => - PublicClientApplicationBuilder.Create(this.applicationClientId) + PublicClientApplicationBuilder.Create(this.configuration.AppClientId) .WithAuthority(authority, audience) - .WithClientName(this.applicationName) + .WithClientName(this.configuration.AppName) .WithLogging(Utils.MSALLogCallback) .WithDefaultRedirectUri() .Build(); diff --git a/src/Microsoft.SqlTools.Authentication/IAuthenticator.cs b/src/Microsoft.SqlTools.Authentication/IAuthenticator.cs new file mode 100644 index 00000000..3be5441a --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/IAuthenticator.cs @@ -0,0 +1,27 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.Authentication +{ + public interface IAuthenticator + { + /// + /// Acquires access token synchronously. + /// + /// Authentication parameters to be used for access token request. + /// Cancellation token. + /// Access Token with expiry date + public AccessToken? GetToken(AuthenticationParams @params, CancellationToken cancellationToken); + + /// + /// Acquires access token asynchronously. + /// + /// Authentication parameters to be used for access token request. + /// Cancellation token. + /// Access Token with expiry date + /// + public Task GetTokenAsync(AuthenticationParams @params, CancellationToken cancellationToken); + } +} diff --git a/src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs b/src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs new file mode 100644 index 00000000..5371a90b --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs @@ -0,0 +1,261 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Text; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Extensions.Msal; +using Logger = Microsoft.SqlTools.Utility.Logger; + +namespace Microsoft.SqlTools.Authentication.Utility +{ + /// + /// This class provides capability to register MSAL Token cache and uses the beforeCacheAccess and afterCacheAccess callbacks + /// to read and write cache to file system. This is done as cache encryption/decryption algorithm is shared between NodeJS and .NET. + /// Because, we cannot use msal-node-extensions in NodeJS, we also cannot use MSAL Extensions Dotnet NuGet package. + /// Ref https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-node-migration#enable-token-caching + /// In future we should use msal extensions to not have to maintain encryption logic in our applications, and also introduce support for + /// token storage options in system keychain/credential store. + /// However - as of now msal-node-extensions does not come with pre-compiled native libraries that causes runtime issues + /// Ref https://github.com/AzureAD/microsoft-authentication-library-for-js/issues/3332 + /// + public class MsalEncryptedCacheHelper + { + /// + /// Callback delegate to be implemented by Services in Service Host, where authentication is performed. e.g. Connection Service. + /// This delegate will be called to retrieve key and IV data if found absent or during instantiation. + /// + /// (out) Key used for encryption/decryption + /// (out) IV used for encryption/decryption + public delegate void IvKeyReadCallback(out string key, out string iv); + + /// + /// Lock objects for serialization + /// + private readonly object _lockObject = new object(); + private CrossPlatLock? _cacheLock = null; + + private AuthenticatorConfiguration _config; + private StorageCreationProperties _storageCreationProperties; + private IvKeyReadCallback _ivKeyReadCallback; + + private byte[]? _iv; + private byte[]? _key; + + /// + /// Storage that handles the storing of the MSAL cache file on disk. + /// + private Storage _cacheStorage { get; } + + #region Public Methods + + /// + /// Instantiates cache encryption helper instance. + /// + /// Configuration containing cache location and name. + /// Delegate callback to retrieve IV and Key from Credential Store when needed. + public MsalEncryptedCacheHelper(AuthenticatorConfiguration config, IvKeyReadCallback callback) + { + this._config = config; + + this._storageCreationProperties = new StorageCreationPropertiesBuilder(config.CacheFileName, config.CacheFolderPath) + .WithCacheChangedEvent(config.AppClientId) + .WithUnprotectedFile().Build(); + + this._cacheStorage = Storage.Create(_storageCreationProperties, Logger.TraceSource); + + this._ivKeyReadCallback = callback; + + this.fillIvKeyIfNeeded(); + } + + /// + /// Registers before and after access methods that are fired on cache access. + /// + /// Access token cache from MSAL.NET + /// When token cache is not provided. + public void RegisterCache(ITokenCache tokenCache) + { + if (tokenCache == null) + { + throw new ArgumentNullException(nameof(tokenCache)); + } + + Logger.Information($"Registering MSAL token cache with encrypted file storage"); + + // If the token cache was already registered, this operation does nothing + tokenCache.SetBeforeAccess(BeforeAccessNotification); + tokenCache.SetAfterAccess(AfterAccessNotification); + } + + #endregion + + #region Private Methods + + private void fillIvKeyIfNeeded() + { + if (this._key == null || this._iv == null) + { + this._ivKeyReadCallback(out string key, out string iv); + + if (key != null) + { + this._key = Encoding.Unicode.GetBytes(key); + } + + if (iv != null) + { + this._iv = Encoding.Unicode.GetBytes(iv); + } + + Logger.Verbose($"Received IV and Key from callback"); + } + } + + /// + /// Triggered after cache is accessed, provides updated cache data that + /// needs to be updated in File Storage. We encrypt cache data here and store it in file system. + /// + /// Access token cache notification arguments. + private void AfterAccessNotification(TokenCacheNotificationArgs args) + { + try + { + Logger.Verbose($"After access"); + byte[]? data = null; + // if the access operation resulted in a cache update + if (args.HasStateChanged) + { + Logger.Verbose($"After access, cache in memory HasChanged"); + try + { + data = args.TokenCache.SerializeMsalV3(); + } + catch (Exception e) + { + Logger.Error($"An exception was encountered while serializing the {nameof(MsalCacheHelper)} : {e}"); + Logger.Error($"No data found in the store, clearing the cache in memory."); + + // The cache is corrupt clear it out + this._cacheStorage.Clear(ignoreExceptions: true); + } + + if (data != null) + { + Logger.Verbose($"Serializing '{data.Length}' bytes"); + + try + { + fillIvKeyIfNeeded(); + var encryptedData = EncryptionUtils.AesEncrypt(data, this._key!, this._iv!); + File.WriteAllText(this._storageCreationProperties.CacheFileName, Convert.ToBase64String(encryptedData)); + } + catch (Exception e) + { + Logger.Error($"Could not write the token cache. Ignoring. {e.Message}"); + } + } + else + { + Logger.Verbose($"No data read from Token Cache"); + } + } + } + finally + { + ReleaseFileLock(); + } + } + + /// + /// Triggered before cache is accessed, we update with data from file storage. + /// Cache file is decrypted and cache data is synced with MSAL.NET memory token cache. + /// + /// Access token cache notification arguments. + private void BeforeAccessNotification(TokenCacheNotificationArgs args) + { + Logger.Verbose($"Before cache access, acquiring lock for token cache"); + + // We have two nested locks here. We need to maintain a clear ordering to avoid deadlocks. + // This is critical to prevent cache corruption and only 1 process accesses cache file at a time. + // 1. Use the CrossPlatLock which is respected by all processes and is used around all cache accesses. + // This lock (using lockfile) is also shared with NodeJS application. + // 2. Use _lockObject which is used in UnregisterCache, and is needed for all accesses of _registeredCaches. + this._cacheLock = CreateCrossPlatLock(_storageCreationProperties); + + Logger.Verbose($"Before access, the store has changed"); + + byte[]? cachedStoreData = null; + byte[]? decryptedData = null; + + try + { + var text = File.ReadAllText(_storageCreationProperties.CacheFilePath); + if (text != null) + { + cachedStoreData = Convert.FromBase64String(text); + fillIvKeyIfNeeded(); + decryptedData = EncryptionUtils.AesDecrypt(cachedStoreData, this._key!, this._iv!); + } + else + { + Logger.Information($"Token cache not received. Ignoring."); + } + } + catch (Exception ex) + { + Logger.Error($"Could not read the token cache. Ignoring. Exception: {ex}"); + return; + + } + Logger.Verbose($"Read '{cachedStoreData?.Length}' bytes from storage"); + + if (decryptedData != null) + { + lock (_lockObject) + { + try + { + Logger.Verbose($"Deserializing the store"); + args.TokenCache.DeserializeMsalV3(decryptedData, shouldClearExistingCache: true); + } + catch (Exception e) + { + Logger.Error($"An exception was encountered while deserializing the {nameof(MsalCacheHelper)} : {e}"); + Logger.Error($"No data found in the store, clearing the cache in memory."); + + // Clear the memory cache without taking the lock over again + this._cacheStorage.Clear(ignoreExceptions: true); + } + } + } + } + + /// + /// Gets a new instance of a lock for synchronizing against a cache made with the same creation properties. + /// + private static CrossPlatLock CreateCrossPlatLock(StorageCreationProperties storageCreationProperties) + { + return new CrossPlatLock( + storageCreationProperties.CacheFilePath + ".lockfile", + storageCreationProperties.LockRetryDelay, + storageCreationProperties.LockRetryCount); + } + + /// + /// Releases file lock by disposing it. + /// + private void ReleaseFileLock() + { + // Get a local copy and call null before disposing because when the lock is disposed the next thread will replace CacheLock with its instance, + // therefore we do not want to null out CacheLock after dispose since this may orphan a CacheLock. + var localDispose = this._cacheLock; + this._cacheLock = null; + localDispose?.Dispose(); + Logger.Information($"Released local lock"); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs index 1f57ed33..37ea2e3c 100644 --- a/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs +++ b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs @@ -4,8 +4,6 @@ // using Microsoft.Data.SqlClient; -using Microsoft.SqlTools.Authentication.Utility; -using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.Authentication.Sql { @@ -17,12 +15,9 @@ namespace Microsoft.SqlTools.Authentication.Sql /// public class AuthenticationProvider : SqlAuthenticationProvider { - private const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255"; - private const string AzureTokenFolder = "Azure Accounts"; - private const string MsalCacheName = "azureTokenCacheMsal-azure_publicCloud"; private const string s_defaultScopeSuffix = "/.default"; - private Authenticator authenticator; + private IAuthenticator authenticator; /// /// Instantiates AuthenticationProvider to be used for AAD authentication with MSAL.NET and MSAL.js co-ordinated. @@ -30,22 +25,9 @@ namespace Microsoft.SqlTools.Authentication.Sql /// Application Name that identifies user folder path location for reading/writing to shared cache. /// Application Path directory where application cache folder is present. /// Callback that handles AAD authentication when user interaction is needed. - public AuthenticationProvider(string applicationName, string applicationPath) + public AuthenticationProvider(IAuthenticator authenticator) { - if (string.IsNullOrEmpty(applicationName)) - { - applicationName = nameof(SqlTools); - Logger.Warning($"Application Name not received with command options, using default application name as: {applicationName}"); - } - - if (string.IsNullOrEmpty(applicationPath)) - { - applicationPath = Utils.BuildAppDirectoryPath(); - Logger.Warning($"Application Path not received with command options, using default application path as: {applicationPath}"); - } - - var cachePath = Path.Combine(applicationPath, applicationName, AzureTokenFolder); - this.authenticator = new Authenticator(ApplicationClientId, applicationName, cachePath, MsalCacheName); + this.authenticator = authenticator; } /// diff --git a/src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs b/src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs new file mode 100644 index 00000000..9173ec62 --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs @@ -0,0 +1,41 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.Authentication.Utility +{ + /// + /// Configuration used by to perform AAD authentication using MSAL.NET + /// + public class AuthenticatorConfiguration + { + /// + /// Application Client ID to be used. + /// + public string AppClientId { get; set; } + + /// + /// Application name used for public client application instantiation. + /// + public string AppName { get; set; } + + /// + /// Cache folder path, to be used by MSAL.NET to store encrypted token cache. + /// + public string CacheFolderPath { get; set; } + + /// + /// File name to be used for token storage. + /// Full path of file: \ + /// + public string CacheFileName { get; set; } + + public AuthenticatorConfiguration(string appClientId, string appName, string cacheFolderPath, string cacheFileName) { + AppClientId = appClientId; + AppName = appName; + CacheFolderPath = cacheFolderPath; + CacheFileName = cacheFileName; + } + } +} diff --git a/src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs b/src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs new file mode 100644 index 00000000..591100bf --- /dev/null +++ b/src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs @@ -0,0 +1,76 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Security.Cryptography; + +namespace Microsoft.SqlTools.Authentication.Utility +{ + public static class EncryptionUtils + { + /// + /// Encrypts provided byte array with 'aes-256-cbc' algorithm. + /// + /// Plain text data + /// Encryption Key + /// Encryption IV + /// Encrypted data in bytes + /// When arguments are null or empty. + public static byte[] AesEncrypt(byte[] plainText, byte[] key, byte[] iv) + { + // Check arguments. + if (plainText == null || plainText.Length <= 0) + { + throw new ArgumentNullException(nameof(plainText)); + } + + using var aes = CreateAes(key, iv); + using var encryptor = aes.CreateEncryptor(); + return encryptor.TransformFinalBlock(plainText, 0, plainText.Length); + } + + /// + /// Decrypts provided byte array with 'aes-256-cbc' algorithm. + /// + /// Encrypted data + /// Encryption Key + /// Encryption IV + /// Plain text data in bytes + /// When arguments are null or empty. + public static byte[] AesDecrypt(byte[] cipherText, byte[] key, byte[] iv) + { + // Check arguments. + if (cipherText == null || cipherText.Length <= 0) + { + throw new ArgumentNullException(nameof(cipherText)); + } + + using var aes = CreateAes(key, iv); + using var decryptor = aes.CreateDecryptor(); + return decryptor.TransformFinalBlock(cipherText, 0, cipherText.Length); + } + + private static Aes CreateAes(byte[] key, byte[] iv) + { + // Check arguments. + if (key == null || key.Length <= 0) + { + throw new ArgumentNullException(nameof(key)); + } + if (iv == null || iv.Length <= 0) + { + throw new ArgumentNullException(nameof(iv)); + } + + var aes = Aes.Create(); + aes.Mode = CipherMode.CBC; + aes.Padding = PaddingMode.PKCS7; + aes.KeySize = 256; + aes.BlockSize = 128; + aes.Key = key; + aes.IV = iv; + return aes; + } + } +} diff --git a/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs index 939e36da..5d80f5e7 100644 --- a/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs +++ b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs @@ -4,7 +4,6 @@ // using System.Net.Mail; -using System.Runtime.InteropServices; using Microsoft.Identity.Client; using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger; @@ -30,59 +29,6 @@ namespace Microsoft.SqlTools.Authentication.Utility } } - /// - /// Builds directory path based on environment settings. - /// - /// Application directory path - /// When called on unsupported platform. - public static string BuildAppDirectoryPath() - { - var homedir = Environment.GetFolderPath(Environment.SpecialFolder.Personal); - - // Windows - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - var appData = Environment.GetEnvironmentVariable("APPDATA"); - var userProfile = Environment.GetEnvironmentVariable("USERPROFILE"); - if (appData != null) - { - return appData; - } - else if (userProfile != null) - { - return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming"); - } - else - { - throw new Exception("Not able to find APPDATA or USERPROFILE"); - } - } - - // Mac - else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - return string.Join(homedir, "Library", "Application Support"); - } - - // Linux - else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME"); - if (xdgConfigHome != null) - { - return xdgConfigHome; - } - else - { - return string.Join(homedir, ".config"); - } - } - else - { - throw new Exception("Platform not supported"); - } - } - /// /// Log callback handler used for MSAL Client applications. /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index c44ea349..e946123e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -26,6 +26,11 @@ using Microsoft.SqlTools.Utility; using static Microsoft.SqlTools.Shared.Utility.Constants; using System.Diagnostics; using Microsoft.SqlTools.Authentication.Sql; +using Microsoft.SqlTools.Credentials; +using Microsoft.SqlTools.Credentials.Contracts; +using Microsoft.SqlTools.Authentication; +using Microsoft.SqlTools.Shared.Utility; +using System.IO; namespace Microsoft.SqlTools.ServiceLayer.Connection { @@ -37,6 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public const string AdminConnectionPrefix = "ADMIN:"; internal const string PasswordPlaceholder = "******"; private const string SqlAzureEdition = "SQL Azure"; + public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens public const int MaxServerlessReconnectTries = 5; // Max number of tries to wait for a serverless database to start up when its paused before giving up. @@ -58,6 +64,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// public static ConnectionService Instance => instance.Value; + /// + /// The authenticator instance for AAD MFA authentication needs. + /// + private IAuthenticator authenticator; + /// /// The SQL connection factory object /// @@ -1072,11 +1083,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions) { this.ServiceHost = serviceHost; - + if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider) { + // Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication. - var provider = new AuthenticationProvider(commandOptions.ApplicationName, commandOptions.ApplicationPath); + var provider = new AuthenticationProvider(GetAuthenticator(commandOptions)); SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider); this.EnableSqlAuthenticationProvider = true; @@ -1134,6 +1146,69 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } + private IAuthenticator GetAuthenticator(CommandOptions commandOptions) + { + var applicationName = commandOptions.ApplicationName; + if (string.IsNullOrEmpty(applicationName)) + { + applicationName = nameof(SqlTools); + Logger.Warning($"Application Name not received with command options, using default application name as: {applicationName}"); + } + + var applicationPath = commandOptions.ApplicationPath; + if (string.IsNullOrEmpty(applicationPath)) + { + applicationPath = Utils.BuildAppDirectoryPath(); + Logger.Warning($"Application Path not received with command options, using default application path as: {applicationPath}"); + } + + var cachePath = Path.Combine(applicationPath, applicationName, AzureTokenFolder); + return new Authenticator(new (ApplicationClientId, applicationName, cachePath, MsalCacheName), ReadCacheIvKey); + } + + private void ReadCacheIvKey(out string? key, out string? iv) + { + Logger.Verbose("Reading Cached IV and Key from OS credential store."); + + iv = null; + key = null; + try + { + // Read Cached Iv for MSAL cache (as Unicode) + Credential ivCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-iv")); + if (!string.IsNullOrEmpty(ivCred.Password)) + { + iv = ivCred.Password; + } + else + { + throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-iv"); + } + } + catch (Exception ex) + { + Logger.Error(ex); + } + + try + { + // Read Cached Key for MSAL cache (as Unicode) + Credential keyCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-key")); + if (!string.IsNullOrEmpty(keyCred.Password)) + { + key = keyCred.Password; + } + else + { + throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-key"); + } + } + catch (Exception ex) + { + Logger.Error(ex); + } + } + private void RunConnectRequestHandlerTask(ConnectParams connectParams) { // create a task to connect asynchronously so that other requests are not blocked in the meantime diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs index 93e2b0f2..ed9189bd 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs @@ -137,7 +137,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace serviceHost.SetEventHandler(DidChangeConfigurationNotification.Type, HandleDidChangeConfigurationNotification); // Register an initialization handler that sets the workspace path - serviceHost.RegisterInitializeTask((parameters, contect) => + serviceHost.RegisterInitializeTask((parameters, context) => { Logger.Write(TraceEventType.Verbose, "Initializing workspace service"); diff --git a/src/Microsoft.SqlTools.Shared/Utility/Constants.cs b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs index d3a9d4e6..b567e705 100644 --- a/src/Microsoft.SqlTools.Shared/Utility/Constants.cs +++ b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs @@ -14,5 +14,11 @@ namespace Microsoft.SqlTools.Shared.Utility public const string dstsAuth = "dstsAuth"; public const string ActiveDirectoryInteractive = "ActiveDirectoryInteractive"; public const string ActiveDirectoryPassword = "ActiveDirectoryPassword"; + + // Azure authentication (MSAL) constants + public const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255"; + public const string AzureTokenFolder = "Azure Accounts"; + public const string AzureAccountProviderCredentials = "azureAccountProviderCredentials"; + public const string MsalCacheName = "accessTokenCache"; } } diff --git a/src/Microsoft.SqlTools.Shared/Utility/Utils.cs b/src/Microsoft.SqlTools.Shared/Utility/Utils.cs new file mode 100644 index 00000000..148f2e0f --- /dev/null +++ b/src/Microsoft.SqlTools.Shared/Utility/Utils.cs @@ -0,0 +1,65 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Runtime.InteropServices; + +namespace Microsoft.SqlTools.Shared.Utility +{ + public static class Utils + { + /// + /// Builds directory path based on environment settings. + /// + /// Application directory path + /// When called on unsupported platform. + public static string BuildAppDirectoryPath() + { + var homedir = Environment.GetFolderPath(Environment.SpecialFolder.Personal); + + // Windows + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var appData = Environment.GetEnvironmentVariable("APPDATA"); + var userProfile = Environment.GetEnvironmentVariable("USERPROFILE"); + if (appData != null) + { + return appData; + } + else if (userProfile != null) + { + return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming"); + } + else + { + throw new Exception("Not able to find APPDATA or USERPROFILE"); + } + } + + // Mac + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return string.Join(homedir, "Library", "Application Support"); + } + + // Linux + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME"); + if (xdgConfigHome != null) + { + return xdgConfigHome; + } + else + { + return string.Join(homedir, ".config"); + } + } + else + { + throw new Exception("Platform not supported"); + } + } + } +}