mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-14 01:25:40 -05:00
MSAL encrypted file system cache (#1945)
This commit is contained in:
@@ -26,6 +26,11 @@ using Microsoft.SqlTools.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using System.Diagnostics;
|
||||
using Microsoft.SqlTools.Authentication.Sql;
|
||||
using Microsoft.SqlTools.Credentials;
|
||||
using Microsoft.SqlTools.Credentials.Contracts;
|
||||
using Microsoft.SqlTools.Authentication;
|
||||
using Microsoft.SqlTools.Shared.Utility;
|
||||
using System.IO;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
@@ -37,6 +42,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
public const string AdminConnectionPrefix = "ADMIN:";
|
||||
internal const string PasswordPlaceholder = "******";
|
||||
private const string SqlAzureEdition = "SQL Azure";
|
||||
|
||||
public const int MaxTolerance = 2 * 60; // two minutes - standard tolerance across ADS for AAD tokens
|
||||
|
||||
public const int MaxServerlessReconnectTries = 5; // Max number of tries to wait for a serverless database to start up when its paused before giving up.
|
||||
@@ -58,6 +64,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
/// </summary>
|
||||
public static ConnectionService Instance => instance.Value;
|
||||
|
||||
/// <summary>
|
||||
/// The authenticator instance for AAD MFA authentication needs.
|
||||
/// </summary>
|
||||
private IAuthenticator authenticator;
|
||||
|
||||
/// <summary>
|
||||
/// The SQL connection factory object
|
||||
/// </summary>
|
||||
@@ -1072,11 +1083,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
|
||||
{
|
||||
this.ServiceHost = serviceHost;
|
||||
|
||||
|
||||
if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
|
||||
{
|
||||
|
||||
// Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
|
||||
var provider = new AuthenticationProvider(commandOptions.ApplicationName, commandOptions.ApplicationPath);
|
||||
var provider = new AuthenticationProvider(GetAuthenticator(commandOptions));
|
||||
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
|
||||
|
||||
this.EnableSqlAuthenticationProvider = true;
|
||||
@@ -1134,6 +1146,69 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
}
|
||||
|
||||
private IAuthenticator GetAuthenticator(CommandOptions commandOptions)
|
||||
{
|
||||
var applicationName = commandOptions.ApplicationName;
|
||||
if (string.IsNullOrEmpty(applicationName))
|
||||
{
|
||||
applicationName = nameof(SqlTools);
|
||||
Logger.Warning($"Application Name not received with command options, using default application name as: {applicationName}");
|
||||
}
|
||||
|
||||
var applicationPath = commandOptions.ApplicationPath;
|
||||
if (string.IsNullOrEmpty(applicationPath))
|
||||
{
|
||||
applicationPath = Utils.BuildAppDirectoryPath();
|
||||
Logger.Warning($"Application Path not received with command options, using default application path as: {applicationPath}");
|
||||
}
|
||||
|
||||
var cachePath = Path.Combine(applicationPath, applicationName, AzureTokenFolder);
|
||||
return new Authenticator(new (ApplicationClientId, applicationName, cachePath, MsalCacheName), ReadCacheIvKey);
|
||||
}
|
||||
|
||||
private void ReadCacheIvKey(out string? key, out string? iv)
|
||||
{
|
||||
Logger.Verbose("Reading Cached IV and Key from OS credential store.");
|
||||
|
||||
iv = null;
|
||||
key = null;
|
||||
try
|
||||
{
|
||||
// Read Cached Iv for MSAL cache (as Unicode)
|
||||
Credential ivCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-iv"));
|
||||
if (!string.IsNullOrEmpty(ivCred.Password))
|
||||
{
|
||||
iv = ivCred.Password;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-iv");
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Logger.Error(ex);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
// Read Cached Key for MSAL cache (as Unicode)
|
||||
Credential keyCred = CredentialService.Instance.ReadCredential(new($"{AzureAccountProviderCredentials}|{MsalCacheName}-key"));
|
||||
if (!string.IsNullOrEmpty(keyCred.Password))
|
||||
{
|
||||
key = keyCred.Password;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new Exception($"Could not read credential: {AzureAccountProviderCredentials}|{MsalCacheName}-key");
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Logger.Error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void RunConnectRequestHandlerTask(ConnectParams connectParams)
|
||||
{
|
||||
// create a task to connect asynchronously so that other requests are not blocked in the meantime
|
||||
|
||||
Reference in New Issue
Block a user