MSAL encrypted file system cache (#1945)

This commit is contained in:
Cheena Malhotra
2023-03-16 11:56:35 -07:00
committed by GitHub
parent f6fbceb5a0
commit 4d9cb17c93
11 changed files with 568 additions and 100 deletions

View File

@@ -26,6 +26,11 @@ using Microsoft.SqlTools.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using System.Diagnostics;
using Microsoft.SqlTools.Authentication.Sql;
using Microsoft.SqlTools.Credentials;
using Microsoft.SqlTools.Credentials.Contracts;
using Microsoft.SqlTools.Authentication;
using Microsoft.SqlTools.Shared.Utility;
using System.IO;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -37,6 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public const string AdminConnectionPrefix = "ADMIN:";
internal const string PasswordPlaceholder = "******";
private const string SqlAzureEdition = "SQL Azure";
public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens
public const int MaxServerlessReconnectTries = 5; // Max number of tries to wait for a serverless database to start up when its paused before giving up.
@@ -58,6 +64,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// </summary>
public static ConnectionService Instance => instance.Value;
/// <summary>
/// The authenticator instance for AAD MFA authentication needs.
/// </summary>
private IAuthenticator authenticator;
/// <summary>
/// The SQL connection factory object
/// </summary>
@@ -1072,11 +1083,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
{
this.ServiceHost = serviceHost;
if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
{
// Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
var provider = new AuthenticationProvider(commandOptions.ApplicationName, commandOptions.ApplicationPath);
var provider = new AuthenticationProvider(GetAuthenticator(commandOptions));
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
this.EnableSqlAuthenticationProvider = true;
@@ -1134,6 +1146,69 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
}
private IAuthenticator GetAuthenticator(CommandOptions commandOptions)
{
var applicationName = commandOptions.ApplicationName;
if (string.IsNullOrEmpty(applicationName))
{
applicationName = nameof(SqlTools);
Logger.Warning($"Application Name not received with command options, using default application name as: {applicationName}");
}
var applicationPath = commandOptions.ApplicationPath;
if (string.IsNullOrEmpty(applicationPath))
{
applicationPath = Utils.BuildAppDirectoryPath();
Logger.Warning($"Application Path not received with command options, using default application path as: {applicationPath}");
}
var cachePath = Path.Combine(applicationPath, applicationName, AzureTokenFolder);
return new Authenticator(new (ApplicationClientId, applicationName, cachePath, MsalCacheName), ReadCacheIvKey);
}
private void ReadCacheIvKey(out string? key, out string? iv)
{
Logger.Verbose("Reading Cached IV and Key from OS credential store.");
iv = null;
key = null;
try
{
// Read Cached Iv for MSAL cache (as Unicode)
Credential ivCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-iv"));
if (!string.IsNullOrEmpty(ivCred.Password))
{
iv = ivCred.Password;
}
else
{
throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-iv");
}
}
catch (Exception ex)
{
Logger.Error(ex);
}
try
{
// Read Cached Key for MSAL cache (as Unicode)
Credential keyCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-key"));
if (!string.IsNullOrEmpty(keyCred.Password))
{
key = keyCred.Password;
}
else
{
throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-key");
}
}
catch (Exception ex)
{
Logger.Error(ex);
}
}
private void RunConnectRequestHandlerTask(ConnectParams connectParams)
{
// create a task to connect asynchronously so that other requests are not blocked in the meantime