Introduce AAD interactive auth mode (#1860)

This commit is contained in:
Cheena Malhotra
2023-03-02 09:39:54 -08:00
committed by GitHub
parent 98e50c98fe
commit 187b6ecc14
47 changed files with 918 additions and 151 deletions

View File

@@ -23,7 +23,9 @@ using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
using static Microsoft.SqlTools.Shared.Utility.Constants;
using System.Diagnostics;
using Microsoft.SqlTools.Authentication.Sql;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -143,17 +145,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
RequestSecurityTokenParams message = new RequestSecurityTokenParams()
{
Authority = authority,
Provider = "Azure",
Authority = authority,
Resource = resource,
Scope = scope
Scopes = new string[] { scope }
};
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true);
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true).ConfigureAwait(false);
return response.Token;
}
/// <summary>
/// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
/// when user chooses 'Azure MFA'.
/// </summary>
public bool EnableSqlAuthenticationProvider { get; set; }
/// <summary>
/// Returns a connection queue for given type
/// </summary>
@@ -248,7 +256,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
if (this.TryFindConnection(ownerUri, out connInfo))
{
// If not an azure connection, no need to refresh token
if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
if (connInfo.ConnectionDetails.AuthenticationType != AzureMFA)
{
return false;
}
@@ -594,7 +602,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connectionInfo.IsSqlDb = serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDatabase;
connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
// Determines that access token is used for creating connection.
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == "AzureMFA";
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == AzureMFA;
connectionInfo.EngineEdition = (DatabaseEngineEdition)serverInfo.EngineEditionId;
// Azure Data Studio supports SQL Server 2014 and later releases.
response.IsSupportedVersion = serverInfo.IsCloud || serverInfo.ServerMajorVersion >= 12;
@@ -1061,10 +1069,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return handler.HandleRequest(this.connectionFactory, info);
}
public void InitializeService(IProtocolEndpoint serviceHost)
public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
{
this.ServiceHost = serviceHost;
if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
{
// Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
var provider = new AuthenticationProvider(commandOptions.ApplicationName);
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
this.EnableSqlAuthenticationProvider = true;
Logger.Information("Registering implementation of SQL Authentication provider for 'Active Directory Interactive' authentication mode.");
}
// Register request and event handlers with the Service Host
serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest, true);
serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest, true);
@@ -1304,31 +1322,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connectionBuilder = new SqlConnectionStringBuilder
{
["Data Source"] = dataSource,
["User Id"] = connectionDetails.UserName,
["Password"] = connectionDetails.Password
DataSource = dataSource
};
}
// Check for any optional parameters
if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
{
connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
connectionBuilder.InitialCatalog = connectionDetails.DatabaseName;
}
if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
{
switch (connectionDetails.AuthenticationType)
{
case "Integrated":
case Integrated:
connectionBuilder.IntegratedSecurity = true;
break;
case "SqlLogin":
case SqlLogin:
connectionBuilder.UserID = connectionDetails.UserName;
connectionBuilder.Password = connectionDetails.Password;
connectionBuilder.Authentication = SqlAuthenticationMethod.SqlPassword;
break;
case "AzureMFA":
connectionBuilder.UserID = "";
connectionBuilder.Password = "";
case AzureMFA:
if (Instance.EnableSqlAuthenticationProvider)
{
connectionBuilder.UserID = connectionDetails.UserName;
connectionDetails.AuthenticationType = ActiveDirectoryInteractive;
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
}
else
{
connectionBuilder.UserID = "";
}
break;
case "ActiveDirectoryPassword":
case ActiveDirectoryInteractive:
connectionBuilder.UserID = connectionDetails.UserName;
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
break;
case ActiveDirectoryPassword:
connectionBuilder.UserID = connectionDetails.UserName;
connectionBuilder.Password = connectionDetails.Password;
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryPassword;
break;
default:
@@ -1337,16 +1370,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting))
{
switch (connectionDetails.ColumnEncryptionSetting.ToUpper())
if (Enum.TryParse<SqlConnectionColumnEncryptionSetting>(connectionDetails.ColumnEncryptionSetting, true, out var value))
{
case "ENABLED":
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Enabled;
break;
case "DISABLED":
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled;
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
connectionBuilder.ColumnEncryptionSetting = value;
}
else
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
}
}
if (!string.IsNullOrEmpty(connectionDetails.SecureEnclaves))
@@ -1365,36 +1395,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationProtocol))
{
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
}
switch (connectionDetails.EnclaveAttestationProtocol.ToUpper())
if (Enum.TryParse<SqlConnectionAttestationProtocol>(connectionDetails.EnclaveAttestationProtocol, true, out var value))
{
case "AAS":
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.AAS;
break;
case "HGS":
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.HGS;
break;
case "NONE":
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.None;
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
connectionBuilder.AttestationProtocol = value;
}
else
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
}
}
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationUrl))
{
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
}
if(connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
if (connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAttestationProtocolNoneWithUrl);
}
@@ -1456,19 +1480,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
{
ApplicationIntent intent;
switch (connectionDetails.ApplicationIntent)
if (Enum.TryParse<ApplicationIntent>(connectionDetails.ApplicationIntent, true, out ApplicationIntent value))
{
case "ReadOnly":
intent = ApplicationIntent.ReadOnly;
break;
case "ReadWrite":
intent = ApplicationIntent.ReadWrite;
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
connectionBuilder.ApplicationIntent = value;
}
else
{
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
}
connectionBuilder.ApplicationIntent = intent;
}
if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
{
@@ -1585,7 +1604,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ApplicationIntent = builder.ApplicationIntent.ToString(),
ApplicationName = builder.ApplicationName,
AttachDbFilename = builder.AttachDBFilename,
AuthenticationType = builder.IntegratedSecurity ? "Integrated" : "SqlLogin",
AuthenticationType = builder.IntegratedSecurity ? "Integrated" :
(builder.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
? "ActiveDirectoryInteractive" : "SqlLogin"),
ConnectRetryCount = builder.ConnectRetryCount,
ConnectRetryInterval = builder.ConnectRetryInterval,
ConnectTimeout = builder.ConnectTimeout,
@@ -1793,7 +1814,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
SqlConnection sqlConn = new SqlConnection(connectionString);
// Fill in Azure authentication token if needed
if (connInfo.ConnectionDetails.AzureAccountToken != null)
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
{
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
@@ -1824,7 +1845,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
ServerConnection serverConnection;
if (connInfo.ConnectionDetails.AzureAccountToken != null)
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
{
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
}