From 4d9cb17c9316f56685c800f4abe8aaa5cd93b626 Mon Sep 17 00:00:00 2001
From: Cheena Malhotra <13396919+cheenamalhotra@users.noreply.github.com>
Date: Thu, 16 Mar 2023 11:56:35 -0700
Subject: [PATCH] MSAL encrypted file system cache (#1945)
---
.../Authenticator.cs | 33 +--
.../IAuthenticator.cs | 27 ++
.../MSALEncryptedCacheHelper.cs | 261 ++++++++++++++++++
.../Sql/AuthenticationProvider.cs | 24 +-
.../Utility/AuthenticatorConfiguration.cs | 41 +++
.../Utility/Encryption/EncryptionUtils.cs | 76 +++++
.../Utility/Utils.cs | 54 ----
.../Connection/ConnectionService.cs | 79 +++++-
.../Workspace/WorkspaceService.cs | 2 +-
.../Utility/Constants.cs | 6 +
.../Utility/Utils.cs | 65 +++++
11 files changed, 568 insertions(+), 100 deletions(-)
create mode 100644 src/Microsoft.SqlTools.Authentication/IAuthenticator.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs
create mode 100644 src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs
create mode 100644 src/Microsoft.SqlTools.Shared/Utility/Utils.cs
diff --git a/src/Microsoft.SqlTools.Authentication/Authenticator.cs b/src/Microsoft.SqlTools.Authentication/Authenticator.cs
index d149a85f..e3f279a0 100644
--- a/src/Microsoft.SqlTools.Authentication/Authenticator.cs
+++ b/src/Microsoft.SqlTools.Authentication/Authenticator.cs
@@ -5,7 +5,6 @@
using System.Collections.Concurrent;
using Microsoft.Identity.Client;
-using Microsoft.Identity.Client.Extensions.Msal;
using Microsoft.SqlTools.Authentication.Utility;
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
@@ -14,30 +13,20 @@ namespace Microsoft.SqlTools.Authentication
///
/// Provides APIs to acquire access token using MSAL.NET v4 with provided .
///
- public class Authenticator
+ public class Authenticator: IAuthenticator
{
- private string applicationClientId;
- private string applicationName;
- private string cacheFolderPath;
- private string cacheFileName;
- private MsalCacheHelper cacheHelper;
+ private AuthenticatorConfiguration configuration;
+
+ private MsalEncryptedCacheHelper msalEncryptedCacheHelper;
+
private static ConcurrentDictionary PublicClientAppMap
= new ConcurrentDictionary();
#region Public APIs
- public Authenticator(string appClientId, string appName, string cacheFolderPath, string cacheFileName)
+ public Authenticator(AuthenticatorConfiguration configuration, MsalEncryptedCacheHelper.IvKeyReadCallback callback)
{
- this.applicationClientId = appClientId;
- this.applicationName = appName;
- this.cacheFolderPath = cacheFolderPath;
- this.cacheFileName = cacheFileName;
-
- // Storage creation properties are used to enable file system caching with MSAL
- var storageCreationProperties = new StorageCreationPropertiesBuilder(this.cacheFileName, this.cacheFolderPath)
- .WithUnprotectedFile().Build();
-
- // This hooks up the cross-platform cache into MSAL
- this.cacheHelper = MsalCacheHelper.CreateAsync(storageCreationProperties).ConfigureAwait(false).GetAwaiter().GetResult();
+ this.configuration = configuration;
+ this.msalEncryptedCacheHelper = new(configuration, callback);
}
///
@@ -139,16 +128,16 @@ namespace Microsoft.SqlTools.Authentication
if (!PublicClientAppMap.TryGetValue(authorityUrl, out IPublicClientApplication? clientApplicationInstance))
{
clientApplicationInstance = CreatePublicClientAppInstance(authority, audience);
- this.cacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache);
+ this.msalEncryptedCacheHelper.RegisterCache(clientApplicationInstance.UserTokenCache);
PublicClientAppMap.TryAdd(authorityUrl, clientApplicationInstance);
}
return clientApplicationInstance;
}
private IPublicClientApplication CreatePublicClientAppInstance(string authority, string audience) =>
- PublicClientApplicationBuilder.Create(this.applicationClientId)
+ PublicClientApplicationBuilder.Create(this.configuration.AppClientId)
.WithAuthority(authority, audience)
- .WithClientName(this.applicationName)
+ .WithClientName(this.configuration.AppName)
.WithLogging(Utils.MSALLogCallback)
.WithDefaultRedirectUri()
.Build();
diff --git a/src/Microsoft.SqlTools.Authentication/IAuthenticator.cs b/src/Microsoft.SqlTools.Authentication/IAuthenticator.cs
new file mode 100644
index 00000000..3be5441a
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/IAuthenticator.cs
@@ -0,0 +1,27 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.Authentication
+{
+ public interface IAuthenticator
+ {
+ ///
+ /// Acquires access token synchronously.
+ ///
+ /// Authentication parameters to be used for access token request.
+ /// Cancellation token.
+ /// Access Token with expiry date
+ public AccessToken? GetToken(AuthenticationParams @params, CancellationToken cancellationToken);
+
+ ///
+ /// Acquires access token asynchronously.
+ ///
+ /// Authentication parameters to be used for access token request.
+ /// Cancellation token.
+ /// Access Token with expiry date
+ ///
+ public Task GetTokenAsync(AuthenticationParams @params, CancellationToken cancellationToken);
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs b/src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs
new file mode 100644
index 00000000..5371a90b
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/MSALEncryptedCacheHelper.cs
@@ -0,0 +1,261 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Text;
+using Microsoft.Identity.Client;
+using Microsoft.Identity.Client.Extensions.Msal;
+using Logger = Microsoft.SqlTools.Utility.Logger;
+
+namespace Microsoft.SqlTools.Authentication.Utility
+{
+ ///
+ /// This class provides capability to register MSAL Token cache and uses the beforeCacheAccess and afterCacheAccess callbacks
+ /// to read and write cache to file system. This is done as cache encryption/decryption algorithm is shared between NodeJS and .NET.
+ /// Because, we cannot use msal-node-extensions in NodeJS, we also cannot use MSAL Extensions Dotnet NuGet package.
+ /// Ref https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-node-migration#enable-token-caching
+ /// In future we should use msal extensions to not have to maintain encryption logic in our applications, and also introduce support for
+ /// token storage options in system keychain/credential store.
+ /// However - as of now msal-node-extensions does not come with pre-compiled native libraries that causes runtime issues
+ /// Ref https://github.com/AzureAD/microsoft-authentication-library-for-js/issues/3332
+ ///
+ public class MsalEncryptedCacheHelper
+ {
+ ///
+ /// Callback delegate to be implemented by Services in Service Host, where authentication is performed. e.g. Connection Service.
+ /// This delegate will be called to retrieve key and IV data if found absent or during instantiation.
+ ///
+ /// (out) Key used for encryption/decryption
+ /// (out) IV used for encryption/decryption
+ public delegate void IvKeyReadCallback(out string key, out string iv);
+
+ ///
+ /// Lock objects for serialization
+ ///
+ private readonly object _lockObject = new object();
+ private CrossPlatLock? _cacheLock = null;
+
+ private AuthenticatorConfiguration _config;
+ private StorageCreationProperties _storageCreationProperties;
+ private IvKeyReadCallback _ivKeyReadCallback;
+
+ private byte[]? _iv;
+ private byte[]? _key;
+
+ ///
+ /// Storage that handles the storing of the MSAL cache file on disk.
+ ///
+ private Storage _cacheStorage { get; }
+
+ #region Public Methods
+
+ ///
+ /// Instantiates cache encryption helper instance.
+ ///
+ /// Configuration containing cache location and name.
+ /// Delegate callback to retrieve IV and Key from Credential Store when needed.
+ public MsalEncryptedCacheHelper(AuthenticatorConfiguration config, IvKeyReadCallback callback)
+ {
+ this._config = config;
+
+ this._storageCreationProperties = new StorageCreationPropertiesBuilder(config.CacheFileName, config.CacheFolderPath)
+ .WithCacheChangedEvent(config.AppClientId)
+ .WithUnprotectedFile().Build();
+
+ this._cacheStorage = Storage.Create(_storageCreationProperties, Logger.TraceSource);
+
+ this._ivKeyReadCallback = callback;
+
+ this.fillIvKeyIfNeeded();
+ }
+
+ ///
+ /// Registers before and after access methods that are fired on cache access.
+ ///
+ /// Access token cache from MSAL.NET
+ /// When token cache is not provided.
+ public void RegisterCache(ITokenCache tokenCache)
+ {
+ if (tokenCache == null)
+ {
+ throw new ArgumentNullException(nameof(tokenCache));
+ }
+
+ Logger.Information($"Registering MSAL token cache with encrypted file storage");
+
+ // If the token cache was already registered, this operation does nothing
+ tokenCache.SetBeforeAccess(BeforeAccessNotification);
+ tokenCache.SetAfterAccess(AfterAccessNotification);
+ }
+
+ #endregion
+
+ #region Private Methods
+
+ private void fillIvKeyIfNeeded()
+ {
+ if (this._key == null || this._iv == null)
+ {
+ this._ivKeyReadCallback(out string key, out string iv);
+
+ if (key != null)
+ {
+ this._key = Encoding.Unicode.GetBytes(key);
+ }
+
+ if (iv != null)
+ {
+ this._iv = Encoding.Unicode.GetBytes(iv);
+ }
+
+ Logger.Verbose($"Received IV and Key from callback");
+ }
+ }
+
+ ///
+ /// Triggered after cache is accessed, provides updated cache data that
+ /// needs to be updated in File Storage. We encrypt cache data here and store it in file system.
+ ///
+ /// Access token cache notification arguments.
+ private void AfterAccessNotification(TokenCacheNotificationArgs args)
+ {
+ try
+ {
+ Logger.Verbose($"After access");
+ byte[]? data = null;
+ // if the access operation resulted in a cache update
+ if (args.HasStateChanged)
+ {
+ Logger.Verbose($"After access, cache in memory HasChanged");
+ try
+ {
+ data = args.TokenCache.SerializeMsalV3();
+ }
+ catch (Exception e)
+ {
+ Logger.Error($"An exception was encountered while serializing the {nameof(MsalCacheHelper)} : {e}");
+ Logger.Error($"No data found in the store, clearing the cache in memory.");
+
+ // The cache is corrupt clear it out
+ this._cacheStorage.Clear(ignoreExceptions: true);
+ }
+
+ if (data != null)
+ {
+ Logger.Verbose($"Serializing '{data.Length}' bytes");
+
+ try
+ {
+ fillIvKeyIfNeeded();
+ var encryptedData = EncryptionUtils.AesEncrypt(data, this._key!, this._iv!);
+ File.WriteAllText(this._storageCreationProperties.CacheFileName, Convert.ToBase64String(encryptedData));
+ }
+ catch (Exception e)
+ {
+ Logger.Error($"Could not write the token cache. Ignoring. {e.Message}");
+ }
+ }
+ else
+ {
+ Logger.Verbose($"No data read from Token Cache");
+ }
+ }
+ }
+ finally
+ {
+ ReleaseFileLock();
+ }
+ }
+
+ ///
+ /// Triggered before cache is accessed, we update with data from file storage.
+ /// Cache file is decrypted and cache data is synced with MSAL.NET memory token cache.
+ ///
+ /// Access token cache notification arguments.
+ private void BeforeAccessNotification(TokenCacheNotificationArgs args)
+ {
+ Logger.Verbose($"Before cache access, acquiring lock for token cache");
+
+ // We have two nested locks here. We need to maintain a clear ordering to avoid deadlocks.
+ // This is critical to prevent cache corruption and only 1 process accesses cache file at a time.
+ // 1. Use the CrossPlatLock which is respected by all processes and is used around all cache accesses.
+ // This lock (using lockfile) is also shared with NodeJS application.
+ // 2. Use _lockObject which is used in UnregisterCache, and is needed for all accesses of _registeredCaches.
+ this._cacheLock = CreateCrossPlatLock(_storageCreationProperties);
+
+ Logger.Verbose($"Before access, the store has changed");
+
+ byte[]? cachedStoreData = null;
+ byte[]? decryptedData = null;
+
+ try
+ {
+ var text = File.ReadAllText(_storageCreationProperties.CacheFilePath);
+ if (text != null)
+ {
+ cachedStoreData = Convert.FromBase64String(text);
+ fillIvKeyIfNeeded();
+ decryptedData = EncryptionUtils.AesDecrypt(cachedStoreData, this._key!, this._iv!);
+ }
+ else
+ {
+ Logger.Information($"Token cache not received. Ignoring.");
+ }
+ }
+ catch (Exception ex)
+ {
+ Logger.Error($"Could not read the token cache. Ignoring. Exception: {ex}");
+ return;
+
+ }
+ Logger.Verbose($"Read '{cachedStoreData?.Length}' bytes from storage");
+
+ if (decryptedData != null)
+ {
+ lock (_lockObject)
+ {
+ try
+ {
+ Logger.Verbose($"Deserializing the store");
+ args.TokenCache.DeserializeMsalV3(decryptedData, shouldClearExistingCache: true);
+ }
+ catch (Exception e)
+ {
+ Logger.Error($"An exception was encountered while deserializing the {nameof(MsalCacheHelper)} : {e}");
+ Logger.Error($"No data found in the store, clearing the cache in memory.");
+
+ // Clear the memory cache without taking the lock over again
+ this._cacheStorage.Clear(ignoreExceptions: true);
+ }
+ }
+ }
+ }
+
+ ///
+ /// Gets a new instance of a lock for synchronizing against a cache made with the same creation properties.
+ ///
+ private static CrossPlatLock CreateCrossPlatLock(StorageCreationProperties storageCreationProperties)
+ {
+ return new CrossPlatLock(
+ storageCreationProperties.CacheFilePath + ".lockfile",
+ storageCreationProperties.LockRetryDelay,
+ storageCreationProperties.LockRetryCount);
+ }
+
+ ///
+ /// Releases file lock by disposing it.
+ ///
+ private void ReleaseFileLock()
+ {
+ // Get a local copy and call null before disposing because when the lock is disposed the next thread will replace CacheLock with its instance,
+ // therefore we do not want to null out CacheLock after dispose since this may orphan a CacheLock.
+ var localDispose = this._cacheLock;
+ this._cacheLock = null;
+ localDispose?.Dispose();
+ Logger.Information($"Released local lock");
+ }
+
+ #endregion
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs
index 1f57ed33..37ea2e3c 100644
--- a/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs
+++ b/src/Microsoft.SqlTools.Authentication/Sql/AuthenticationProvider.cs
@@ -4,8 +4,6 @@
//
using Microsoft.Data.SqlClient;
-using Microsoft.SqlTools.Authentication.Utility;
-using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.Authentication.Sql
{
@@ -17,12 +15,9 @@ namespace Microsoft.SqlTools.Authentication.Sql
///
public class AuthenticationProvider : SqlAuthenticationProvider
{
- private const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255";
- private const string AzureTokenFolder = "Azure Accounts";
- private const string MsalCacheName = "azureTokenCacheMsal-azure_publicCloud";
private const string s_defaultScopeSuffix = "/.default";
- private Authenticator authenticator;
+ private IAuthenticator authenticator;
///
/// Instantiates AuthenticationProvider to be used for AAD authentication with MSAL.NET and MSAL.js co-ordinated.
@@ -30,22 +25,9 @@ namespace Microsoft.SqlTools.Authentication.Sql
/// Application Name that identifies user folder path location for reading/writing to shared cache.
/// Application Path directory where application cache folder is present.
/// Callback that handles AAD authentication when user interaction is needed.
- public AuthenticationProvider(string applicationName, string applicationPath)
+ public AuthenticationProvider(IAuthenticator authenticator)
{
- if (string.IsNullOrEmpty(applicationName))
- {
- applicationName = nameof(SqlTools);
- Logger.Warning($"Application Name not received with command options, using default application name as: {applicationName}");
- }
-
- if (string.IsNullOrEmpty(applicationPath))
- {
- applicationPath = Utils.BuildAppDirectoryPath();
- Logger.Warning($"Application Path not received with command options, using default application path as: {applicationPath}");
- }
-
- var cachePath = Path.Combine(applicationPath, applicationName, AzureTokenFolder);
- this.authenticator = new Authenticator(ApplicationClientId, applicationName, cachePath, MsalCacheName);
+ this.authenticator = authenticator;
}
///
diff --git a/src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs b/src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs
new file mode 100644
index 00000000..9173ec62
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/Utility/AuthenticatorConfiguration.cs
@@ -0,0 +1,41 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.Authentication.Utility
+{
+ ///
+ /// Configuration used by to perform AAD authentication using MSAL.NET
+ ///
+ public class AuthenticatorConfiguration
+ {
+ ///
+ /// Application Client ID to be used.
+ ///
+ public string AppClientId { get; set; }
+
+ ///
+ /// Application name used for public client application instantiation.
+ ///
+ public string AppName { get; set; }
+
+ ///
+ /// Cache folder path, to be used by MSAL.NET to store encrypted token cache.
+ ///
+ public string CacheFolderPath { get; set; }
+
+ ///
+ /// File name to be used for token storage.
+ /// Full path of file: \
+ ///
+ public string CacheFileName { get; set; }
+
+ public AuthenticatorConfiguration(string appClientId, string appName, string cacheFolderPath, string cacheFileName) {
+ AppClientId = appClientId;
+ AppName = appName;
+ CacheFolderPath = cacheFolderPath;
+ CacheFileName = cacheFileName;
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs b/src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs
new file mode 100644
index 00000000..591100bf
--- /dev/null
+++ b/src/Microsoft.SqlTools.Authentication/Utility/Encryption/EncryptionUtils.cs
@@ -0,0 +1,76 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Security.Cryptography;
+
+namespace Microsoft.SqlTools.Authentication.Utility
+{
+ public static class EncryptionUtils
+ {
+ ///
+ /// Encrypts provided byte array with 'aes-256-cbc' algorithm.
+ ///
+ /// Plain text data
+ /// Encryption Key
+ /// Encryption IV
+ /// Encrypted data in bytes
+ /// When arguments are null or empty.
+ public static byte[] AesEncrypt(byte[] plainText, byte[] key, byte[] iv)
+ {
+ // Check arguments.
+ if (plainText == null || plainText.Length <= 0)
+ {
+ throw new ArgumentNullException(nameof(plainText));
+ }
+
+ using var aes = CreateAes(key, iv);
+ using var encryptor = aes.CreateEncryptor();
+ return encryptor.TransformFinalBlock(plainText, 0, plainText.Length);
+ }
+
+ ///
+ /// Decrypts provided byte array with 'aes-256-cbc' algorithm.
+ ///
+ /// Encrypted data
+ /// Encryption Key
+ /// Encryption IV
+ /// Plain text data in bytes
+ /// When arguments are null or empty.
+ public static byte[] AesDecrypt(byte[] cipherText, byte[] key, byte[] iv)
+ {
+ // Check arguments.
+ if (cipherText == null || cipherText.Length <= 0)
+ {
+ throw new ArgumentNullException(nameof(cipherText));
+ }
+
+ using var aes = CreateAes(key, iv);
+ using var decryptor = aes.CreateDecryptor();
+ return decryptor.TransformFinalBlock(cipherText, 0, cipherText.Length);
+ }
+
+ private static Aes CreateAes(byte[] key, byte[] iv)
+ {
+ // Check arguments.
+ if (key == null || key.Length <= 0)
+ {
+ throw new ArgumentNullException(nameof(key));
+ }
+ if (iv == null || iv.Length <= 0)
+ {
+ throw new ArgumentNullException(nameof(iv));
+ }
+
+ var aes = Aes.Create();
+ aes.Mode = CipherMode.CBC;
+ aes.Padding = PaddingMode.PKCS7;
+ aes.KeySize = 256;
+ aes.BlockSize = 128;
+ aes.Key = key;
+ aes.IV = iv;
+ return aes;
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
index 939e36da..5d80f5e7 100644
--- a/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
+++ b/src/Microsoft.SqlTools.Authentication/Utility/Utils.cs
@@ -4,7 +4,6 @@
//
using System.Net.Mail;
-using System.Runtime.InteropServices;
using Microsoft.Identity.Client;
using SqlToolsLogger = Microsoft.SqlTools.Utility.Logger;
@@ -30,59 +29,6 @@ namespace Microsoft.SqlTools.Authentication.Utility
}
}
- ///
- /// Builds directory path based on environment settings.
- ///
- /// Application directory path
- /// When called on unsupported platform.
- public static string BuildAppDirectoryPath()
- {
- var homedir = Environment.GetFolderPath(Environment.SpecialFolder.Personal);
-
- // Windows
- if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
- {
- var appData = Environment.GetEnvironmentVariable("APPDATA");
- var userProfile = Environment.GetEnvironmentVariable("USERPROFILE");
- if (appData != null)
- {
- return appData;
- }
- else if (userProfile != null)
- {
- return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming");
- }
- else
- {
- throw new Exception("Not able to find APPDATA or USERPROFILE");
- }
- }
-
- // Mac
- else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
- {
- return string.Join(homedir, "Library", "Application Support");
- }
-
- // Linux
- else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
- {
- var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME");
- if (xdgConfigHome != null)
- {
- return xdgConfigHome;
- }
- else
- {
- return string.Join(homedir, ".config");
- }
- }
- else
- {
- throw new Exception("Platform not supported");
- }
- }
-
///
/// Log callback handler used for MSAL Client applications.
///
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index c44ea349..e946123e 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -26,6 +26,11 @@ using Microsoft.SqlTools.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using System.Diagnostics;
using Microsoft.SqlTools.Authentication.Sql;
+using Microsoft.SqlTools.Credentials;
+using Microsoft.SqlTools.Credentials.Contracts;
+using Microsoft.SqlTools.Authentication;
+using Microsoft.SqlTools.Shared.Utility;
+using System.IO;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -37,6 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public const string AdminConnectionPrefix = "ADMIN:";
internal const string PasswordPlaceholder = "******";
private const string SqlAzureEdition = "SQL Azure";
+
public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens
public const int MaxServerlessReconnectTries = 5; // Max number of tries to wait for a serverless database to start up when its paused before giving up.
@@ -58,6 +64,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
///
public static ConnectionService Instance => instance.Value;
+ ///
+ /// The authenticator instance for AAD MFA authentication needs.
+ ///
+ private IAuthenticator authenticator;
+
///
/// The SQL connection factory object
///
@@ -1072,11 +1083,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
{
this.ServiceHost = serviceHost;
-
+
if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
{
+
// Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
- var provider = new AuthenticationProvider(commandOptions.ApplicationName, commandOptions.ApplicationPath);
+ var provider = new AuthenticationProvider(GetAuthenticator(commandOptions));
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
this.EnableSqlAuthenticationProvider = true;
@@ -1134,6 +1146,69 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
}
+ private IAuthenticator GetAuthenticator(CommandOptions commandOptions)
+ {
+ var applicationName = commandOptions.ApplicationName;
+ if (string.IsNullOrEmpty(applicationName))
+ {
+ applicationName = nameof(SqlTools);
+ Logger.Warning($"Application Name not received with command options, using default application name as: {applicationName}");
+ }
+
+ var applicationPath = commandOptions.ApplicationPath;
+ if (string.IsNullOrEmpty(applicationPath))
+ {
+ applicationPath = Utils.BuildAppDirectoryPath();
+ Logger.Warning($"Application Path not received with command options, using default application path as: {applicationPath}");
+ }
+
+ var cachePath = Path.Combine(applicationPath, applicationName, AzureTokenFolder);
+ return new Authenticator(new (ApplicationClientId, applicationName, cachePath, MsalCacheName), ReadCacheIvKey);
+ }
+
+ private void ReadCacheIvKey(out string? key, out string? iv)
+ {
+ Logger.Verbose("Reading Cached IV and Key from OS credential store.");
+
+ iv = null;
+ key = null;
+ try
+ {
+ // Read Cached Iv for MSAL cache (as Unicode)
+ Credential ivCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-iv"));
+ if (!string.IsNullOrEmpty(ivCred.Password))
+ {
+ iv = ivCred.Password;
+ }
+ else
+ {
+ throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-iv");
+ }
+ }
+ catch (Exception ex)
+ {
+ Logger.Error(ex);
+ }
+
+ try
+ {
+ // Read Cached Key for MSAL cache (as Unicode)
+ Credential keyCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-key"));
+ if (!string.IsNullOrEmpty(keyCred.Password))
+ {
+ key = keyCred.Password;
+ }
+ else
+ {
+ throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-key");
+ }
+ }
+ catch (Exception ex)
+ {
+ Logger.Error(ex);
+ }
+ }
+
private void RunConnectRequestHandlerTask(ConnectParams connectParams)
{
// create a task to connect asynchronously so that other requests are not blocked in the meantime
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs
index 93e2b0f2..ed9189bd 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/WorkspaceService.cs
@@ -137,7 +137,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace
serviceHost.SetEventHandler(DidChangeConfigurationNotification.Type, HandleDidChangeConfigurationNotification);
// Register an initialization handler that sets the workspace path
- serviceHost.RegisterInitializeTask((parameters, contect) =>
+ serviceHost.RegisterInitializeTask((parameters, context) =>
{
Logger.Write(TraceEventType.Verbose, "Initializing workspace service");
diff --git a/src/Microsoft.SqlTools.Shared/Utility/Constants.cs b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs
index d3a9d4e6..b567e705 100644
--- a/src/Microsoft.SqlTools.Shared/Utility/Constants.cs
+++ b/src/Microsoft.SqlTools.Shared/Utility/Constants.cs
@@ -14,5 +14,11 @@ namespace Microsoft.SqlTools.Shared.Utility
public const string dstsAuth = "dstsAuth";
public const string ActiveDirectoryInteractive = "ActiveDirectoryInteractive";
public const string ActiveDirectoryPassword = "ActiveDirectoryPassword";
+
+ // Azure authentication (MSAL) constants
+ public const string ApplicationClientId = "a69788c6-1d43-44ed-9ca3-b83e194da255";
+ public const string AzureTokenFolder = "Azure Accounts";
+ public const string AzureAccountProviderCredentials = "azureAccountProviderCredentials";
+ public const string MsalCacheName = "accessTokenCache";
}
}
diff --git a/src/Microsoft.SqlTools.Shared/Utility/Utils.cs b/src/Microsoft.SqlTools.Shared/Utility/Utils.cs
new file mode 100644
index 00000000..148f2e0f
--- /dev/null
+++ b/src/Microsoft.SqlTools.Shared/Utility/Utils.cs
@@ -0,0 +1,65 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Runtime.InteropServices;
+
+namespace Microsoft.SqlTools.Shared.Utility
+{
+ public static class Utils
+ {
+ ///
+ /// Builds directory path based on environment settings.
+ ///
+ /// Application directory path
+ /// When called on unsupported platform.
+ public static string BuildAppDirectoryPath()
+ {
+ var homedir = Environment.GetFolderPath(Environment.SpecialFolder.Personal);
+
+ // Windows
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ var appData = Environment.GetEnvironmentVariable("APPDATA");
+ var userProfile = Environment.GetEnvironmentVariable("USERPROFILE");
+ if (appData != null)
+ {
+ return appData;
+ }
+ else if (userProfile != null)
+ {
+ return string.Join(Environment.GetEnvironmentVariable("USERPROFILE"), "AppData", "Roaming");
+ }
+ else
+ {
+ throw new Exception("Not able to find APPDATA or USERPROFILE");
+ }
+ }
+
+ // Mac
+ else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
+ {
+ return string.Join(homedir, "Library", "Application Support");
+ }
+
+ // Linux
+ else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
+ {
+ var xdgConfigHome = Environment.GetEnvironmentVariable("XDG_CONFIG_HOME");
+ if (xdgConfigHome != null)
+ {
+ return xdgConfigHome;
+ }
+ else
+ {
+ return string.Join(homedir, ".config");
+ }
+ }
+ else
+ {
+ throw new Exception("Platform not supported");
+ }
+ }
+ }
+}