mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-14 01:25:40 -05:00
Introduce AAD interactive auth mode (#1860)
This commit is contained in:
@@ -23,7 +23,9 @@ using Microsoft.SqlTools.ServiceLayer.LanguageServices;
|
||||
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
|
||||
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||
using Microsoft.SqlTools.Utility;
|
||||
using static Microsoft.SqlTools.Shared.Utility.Constants;
|
||||
using System.Diagnostics;
|
||||
using Microsoft.SqlTools.Authentication.Sql;
|
||||
|
||||
namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
@@ -143,17 +145,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
RequestSecurityTokenParams message = new RequestSecurityTokenParams()
|
||||
{
|
||||
Authority = authority,
|
||||
Provider = "Azure",
|
||||
Authority = authority,
|
||||
Resource = resource,
|
||||
Scope = scope
|
||||
Scopes = new string[] { scope }
|
||||
};
|
||||
|
||||
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true);
|
||||
RequestSecurityTokenResponse response = await Instance.ServiceHost.SendRequest(SecurityTokenRequest.Type, message, true).ConfigureAwait(false);
|
||||
|
||||
return response.Token;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Enables configured 'Sql Authentication Provider' for 'Active Directory Interactive' authentication mode to be used
|
||||
/// when user chooses 'Azure MFA'.
|
||||
/// </summary>
|
||||
public bool EnableSqlAuthenticationProvider { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Returns a connection queue for given type
|
||||
/// </summary>
|
||||
@@ -248,7 +256,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
if (this.TryFindConnection(ownerUri, out connInfo))
|
||||
{
|
||||
// If not an azure connection, no need to refresh token
|
||||
if (connInfo.ConnectionDetails.AuthenticationType != "AzureMFA")
|
||||
if (connInfo.ConnectionDetails.AuthenticationType != AzureMFA)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -594,7 +602,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
connectionInfo.IsSqlDb = serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDatabase;
|
||||
connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
|
||||
// Determines that access token is used for creating connection.
|
||||
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == "AzureMFA";
|
||||
connectionInfo.IsAzureAuth = connectionInfo.ConnectionDetails.AuthenticationType == AzureMFA;
|
||||
connectionInfo.EngineEdition = (DatabaseEngineEdition)serverInfo.EngineEditionId;
|
||||
// Azure Data Studio supports SQL Server 2014 and later releases.
|
||||
response.IsSupportedVersion = serverInfo.IsCloud || serverInfo.ServerMajorVersion >= 12;
|
||||
@@ -1061,10 +1069,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
return handler.HandleRequest(this.connectionFactory, info);
|
||||
}
|
||||
|
||||
public void InitializeService(IProtocolEndpoint serviceHost)
|
||||
public void InitializeService(IProtocolEndpoint serviceHost, ServiceLayerCommandOptions commandOptions)
|
||||
{
|
||||
this.ServiceHost = serviceHost;
|
||||
|
||||
if (commandOptions != null && commandOptions.EnableSqlAuthenticationProvider)
|
||||
{
|
||||
// Register SqlAuthenticationProvider with SqlConnection for AAD Interactive (MFA) authentication.
|
||||
var provider = new AuthenticationProvider(commandOptions.ApplicationName);
|
||||
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryInteractive, provider);
|
||||
|
||||
this.EnableSqlAuthenticationProvider = true;
|
||||
Logger.Information("Registering implementation of SQL Authentication provider for 'Active Directory Interactive' authentication mode.");
|
||||
}
|
||||
|
||||
// Register request and event handlers with the Service Host
|
||||
serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest, true);
|
||||
serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest, true);
|
||||
@@ -1304,31 +1322,46 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
|
||||
connectionBuilder = new SqlConnectionStringBuilder
|
||||
{
|
||||
["Data Source"] = dataSource,
|
||||
["User Id"] = connectionDetails.UserName,
|
||||
["Password"] = connectionDetails.Password
|
||||
DataSource = dataSource
|
||||
};
|
||||
}
|
||||
|
||||
// Check for any optional parameters
|
||||
if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
|
||||
{
|
||||
connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
|
||||
connectionBuilder.InitialCatalog = connectionDetails.DatabaseName;
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
|
||||
{
|
||||
switch (connectionDetails.AuthenticationType)
|
||||
{
|
||||
case "Integrated":
|
||||
case Integrated:
|
||||
connectionBuilder.IntegratedSecurity = true;
|
||||
break;
|
||||
case "SqlLogin":
|
||||
case SqlLogin:
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionBuilder.Password = connectionDetails.Password;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.SqlPassword;
|
||||
break;
|
||||
case "AzureMFA":
|
||||
connectionBuilder.UserID = "";
|
||||
connectionBuilder.Password = "";
|
||||
case AzureMFA:
|
||||
if (Instance.EnableSqlAuthenticationProvider)
|
||||
{
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionDetails.AuthenticationType = ActiveDirectoryInteractive;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
|
||||
}
|
||||
else
|
||||
{
|
||||
connectionBuilder.UserID = "";
|
||||
}
|
||||
break;
|
||||
case "ActiveDirectoryPassword":
|
||||
case ActiveDirectoryInteractive:
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryInteractive;
|
||||
break;
|
||||
case ActiveDirectoryPassword:
|
||||
connectionBuilder.UserID = connectionDetails.UserName;
|
||||
connectionBuilder.Password = connectionDetails.Password;
|
||||
connectionBuilder.Authentication = SqlAuthenticationMethod.ActiveDirectoryPassword;
|
||||
break;
|
||||
default:
|
||||
@@ -1337,16 +1370,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting))
|
||||
{
|
||||
switch (connectionDetails.ColumnEncryptionSetting.ToUpper())
|
||||
if (Enum.TryParse<SqlConnectionColumnEncryptionSetting>(connectionDetails.ColumnEncryptionSetting, true, out var value))
|
||||
{
|
||||
case "ENABLED":
|
||||
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Enabled;
|
||||
break;
|
||||
case "DISABLED":
|
||||
connectionBuilder.ColumnEncryptionSetting = SqlConnectionColumnEncryptionSetting.Disabled;
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
|
||||
connectionBuilder.ColumnEncryptionSetting = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidColumnEncryptionSetting(connectionDetails.ColumnEncryptionSetting));
|
||||
}
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.SecureEnclaves))
|
||||
@@ -1365,36 +1395,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationProtocol))
|
||||
{
|
||||
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
|
||||
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|
||||
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
|
||||
}
|
||||
|
||||
switch (connectionDetails.EnclaveAttestationProtocol.ToUpper())
|
||||
if (Enum.TryParse<SqlConnectionAttestationProtocol>(connectionDetails.EnclaveAttestationProtocol, true, out var value))
|
||||
{
|
||||
case "AAS":
|
||||
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.AAS;
|
||||
break;
|
||||
case "HGS":
|
||||
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.HGS;
|
||||
break;
|
||||
case "NONE":
|
||||
connectionBuilder.AttestationProtocol = SqlConnectionAttestationProtocol.None;
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
|
||||
connectionBuilder.AttestationProtocol = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidEnclaveAttestationProtocol(connectionDetails.EnclaveAttestationProtocol));
|
||||
}
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.EnclaveAttestationUrl))
|
||||
{
|
||||
if (string.IsNullOrEmpty(connectionDetails.ColumnEncryptionSetting) || connectionDetails.ColumnEncryptionSetting.ToUpper() == "DISABLED"
|
||||
if (connectionBuilder.ColumnEncryptionSetting != SqlConnectionColumnEncryptionSetting.Enabled
|
||||
|| string.IsNullOrEmpty(connectionDetails.SecureEnclaves) || connectionDetails.SecureEnclaves.ToUpper() == "DISABLED")
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAlwaysEncryptedOptionCombination);
|
||||
}
|
||||
|
||||
if(connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
|
||||
if (connectionBuilder.AttestationProtocol == SqlConnectionAttestationProtocol.None)
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAttestationProtocolNoneWithUrl);
|
||||
}
|
||||
@@ -1456,19 +1480,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
|
||||
{
|
||||
ApplicationIntent intent;
|
||||
switch (connectionDetails.ApplicationIntent)
|
||||
if (Enum.TryParse<ApplicationIntent>(connectionDetails.ApplicationIntent, true, out ApplicationIntent value))
|
||||
{
|
||||
case "ReadOnly":
|
||||
intent = ApplicationIntent.ReadOnly;
|
||||
break;
|
||||
case "ReadWrite":
|
||||
intent = ApplicationIntent.ReadWrite;
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
|
||||
connectionBuilder.ApplicationIntent = value;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
|
||||
}
|
||||
connectionBuilder.ApplicationIntent = intent;
|
||||
}
|
||||
if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
|
||||
{
|
||||
@@ -1585,7 +1604,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
ApplicationIntent = builder.ApplicationIntent.ToString(),
|
||||
ApplicationName = builder.ApplicationName,
|
||||
AttachDbFilename = builder.AttachDBFilename,
|
||||
AuthenticationType = builder.IntegratedSecurity ? "Integrated" : "SqlLogin",
|
||||
AuthenticationType = builder.IntegratedSecurity ? "Integrated" :
|
||||
(builder.Authentication == SqlAuthenticationMethod.ActiveDirectoryInteractive
|
||||
? "ActiveDirectoryInteractive" : "SqlLogin"),
|
||||
ConnectRetryCount = builder.ConnectRetryCount,
|
||||
ConnectRetryInterval = builder.ConnectRetryInterval,
|
||||
ConnectTimeout = builder.ConnectTimeout,
|
||||
@@ -1793,7 +1814,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
SqlConnection sqlConn = new SqlConnection(connectionString);
|
||||
|
||||
// Fill in Azure authentication token if needed
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
|
||||
{
|
||||
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||
}
|
||||
@@ -1824,7 +1845,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
{
|
||||
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
|
||||
ServerConnection serverConnection;
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null && connInfo.ConnectionDetails.AuthenticationType == AzureMFA)
|
||||
{
|
||||
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user