Add support for Azure Active Directory connections (#727)

This commit is contained in:
Matt Irvine
2018-11-13 11:50:30 -08:00
committed by GitHub
parent 2cb7f682c5
commit 7f28f249de
32 changed files with 291 additions and 121 deletions

View File

@@ -48,8 +48,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ValueType = ConnectionOption.ValueTypeCategory,
SpecialValueType = ConnectionOption.SpecialValueAuthType,
CategoryValues = new CategoryValue[]
{ new CategoryValue {DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue {DisplayName = "Windows Authentication", Name= "Integrated" }
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
},
IsIdentity = true,
IsRequired = true,

View File

@@ -523,7 +523,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
// create a sql connection instance
connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
connection = connectionInfo.Factory.CreateSqlConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken);
connectionInfo.AddConnection(connectionParams.Type, connection);
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
@@ -909,7 +909,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// Connect to master and query sys.databases
connectionDetails.DatabaseName = "master";
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails));
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken);
connection.Open();
List<string> results = new List<string>();
@@ -1151,6 +1151,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
break;
case "SqlLogin":
break;
case "AzureMFA":
connectionBuilder.UserID = "";
connectionBuilder.Password = "";
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
}
@@ -1387,7 +1391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
string connectionString = BuildConnectionString(info.ConnectionDetails);
// create a sql connection instance
DbConnection connection = info.Factory.CreateSqlConnection(connectionString);
DbConnection connection = info.Factory.CreateSqlConnection(connectionString, info.ConnectionDetails.AzureAccountToken);
connection.Open();
info.AddConnection(key, connection);
}
@@ -1488,6 +1492,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// Note: we need to audit all uses of this method to determine why we're
/// bypassing normal ConnectionService connection management
/// </summary>
/// <param name="connInfo">The connection info to connect with</param>
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
/// <returns>A SqlConnection created with the given connection info</returns>
internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null)
{
try
@@ -1515,6 +1522,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// open a dedicated binding server connection
SqlConnection sqlConn = new SqlConnection(connectionString);
// Fill in Azure authentication token if needed
if (connInfo.ConnectionDetails.AzureAccountToken != null)
{
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
sqlConn.Open();
return sqlConn;
}
@@ -1529,6 +1543,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return null;
}
/// <summary>
/// Create and open a new ServerConnection from a ConnectionInfo object.
/// This calls ConnectionService.OpenSqlConnection and then creates a
/// ServerConnection from it.
/// </summary>
/// <param name="connInfo">The connection info to connect with</param>
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
/// <returns>A ServerConnection (wrapping a SqlConnection) created with the given connection info</returns>
internal static ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName = null)
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
ServerConnection serverConnection;
if (connInfo.ConnectionDetails.AzureAccountToken != null)
{
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
}
else
{
serverConnection = new ServerConnection(sqlConnection);
}
return serverConnection;
}
public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false)
{
// verify that the connection is open
@@ -1552,4 +1590,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
}
}
public class AzureAccessToken : IRenewableToken
{
public DateTimeOffset TokenExpiry { get; set; }
public string Resource { get; set; }
public string Tenant { get; set; }
public string UserId { get; set; }
private string accessToken;
public AzureAccessToken(string accessToken)
{
this.accessToken = accessToken;
}
public string GetAccessToken()
{
return this.accessToken;
}
}
}

View File

@@ -497,6 +497,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
}
}
public string AzureAccountToken
{
get
{
return GetOptionValue<string>("azureAccountToken");
}
set
{
SetOptionValue("azureAccountToken", value);
}
}
public bool IsComparableTo(ConnectionDetails other)
{
if (other == null)
@@ -506,7 +518,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
if (ServerName != other.ServerName
|| AuthenticationType != other.AuthenticationType
|| UserName != other.UserName)
|| UserName != other.UserName
|| AzureAccountToken != other.AzureAccountToken)
{
return false;
}

View File

@@ -44,7 +44,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
PacketSize = details.PacketSize,
TypeSystemVersion = details.TypeSystemVersion,
ConnectionString = details.ConnectionString,
Port = details.Port
Port = details.Port,
AzureAccountToken = details.AzureAccountToken
};
}
}

View File

@@ -15,6 +15,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// <summary>
/// Create a new SQL Connection object
/// </summary>
DbConnection CreateSqlConnection(string connectionString);
DbConnection CreateSqlConnection(string connectionString, string azureAccountToken);
}
}

View File

@@ -46,10 +46,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// Opens the connection and sets the lock/command timeout and pooling=false.
/// </summary>
/// <returns>The opened connection</returns>
public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry)
public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry, string azureAccountToken)
{
csb.Pooling = false;
return OpenConnection(csb.ToString(), useRetry);
return OpenConnection(csb.ToString(), useRetry, azureAccountToken);
}
/// <summary>
@@ -57,7 +57,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// will assert if pooling!=false.
/// </summary>
/// <returns>The opened connection</returns>
public static IDbConnection OpenConnection(string connectionString, bool useRetry)
public static IDbConnection OpenConnection(string connectionString, bool useRetry, string azureAccountToken)
{
#if DEBUG
try
@@ -88,7 +88,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy();
}
ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
try
{
@@ -136,7 +136,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
SqlConnectionStringBuilder csb,
Action<IDbConnection> usingConnection,
Predicate<Exception> catchException,
bool useRetry)
bool useRetry,
string azureAccountToken)
{
Validate.IsNotNull(nameof(csb), csb);
Validate.IsNotNull(nameof(usingConnection), usingConnection);
@@ -145,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
{
// Always disable pooling
csb.Pooling = false;
using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry))
using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry, azureAccountToken))
{
usingConnection(conn);
}
@@ -228,7 +229,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
string commandText,
Action<IDbCommand> initializeCommand,
Predicate<Exception> catchException,
bool useRetry)
bool useRetry,
string azureAccountToken)
{
object retObject = null;
OpenConnection(
@@ -238,7 +240,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException);
},
catchException,
useRetry);
useRetry,
azureAccountToken);
return retObject;
}
@@ -636,7 +639,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// <summary>
/// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler.
/// </summary>
public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder)
public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder, string azureAccountToken)
{
Validate.IsNotNull(nameof(builder), builder);
@@ -670,7 +673,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
Logger.Write(TraceEventType.Error, ex.ToString());
return StandardExceptionHandler(ex); // handled
},
useRetry: true);
useRetry: true,
azureAccountToken: azureAccountToken);
return isDatabaseReadOnly;
}
@@ -697,7 +701,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
public string MachineName;
}
public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo)
public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo, string azureAccountToken)
{
serverInfo = null;
if (!TryGetConnectionStringBuilder(connectionString, out SqlConnectionStringBuilder builder))
@@ -705,14 +709,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
return false;
}
serverInfo = GetServerVersion(builder);
serverInfo = GetServerVersion(builder, azureAccountToken);
return true;
}
/// <summary>
/// Returns the version of the server. This routine will throw if an exception is encountered.
/// </summary>
public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb)
public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb, string azureAccountToken)
{
Validate.IsNotNull(nameof(csb), csb);
ServerInfo serverInfo = null;
@@ -724,7 +728,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
serverInfo = GetServerVersion(connection);
},
catchException: null, // Always throw
useRetry: true);
useRetry: true,
azureAccountToken: azureAccountToken);
return serverInfo;
}
@@ -1057,7 +1062,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// Returns true if the authenticating database is master, otherwise false. An example of
/// false is when the user is a contained user connecting to a contained database.
/// </summary>
public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder)
public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder, string azureAccountToken)
{
bool authIsMaster = true;
OpenConnection(
@@ -1067,7 +1072,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
authIsMaster = IsAuthenticatingDatabaseMaster(connection);
},
catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception
useRetry: true);
useRetry: true,
azureAccountToken: azureAccountToken);
return authIsMaster;
}

View File

@@ -59,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// <param name="connectionString">The connection string used to open the SQL Azure database.</param>
/// <param name="connectionRetryPolicy">The retry policy defining whether to retry a request if a connection fails to be established.</param>
/// <param name="commandRetryPolicy">The retry policy defining whether to retry a request if a command fails to be executed.</param>
public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy)
public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy, string azureAccountToken)
{
_underlyingConnection = new SqlConnection(connectionString);
_connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy();
@@ -68,6 +68,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
_underlyingConnection.StateChange += OnConnectionStateChange;
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
_commandRetryPolicy.RetryOccurred += RetryCommandCallback;
if (azureAccountToken != null)
{
_underlyingConnection.AccessToken = azureAccountToken;
}
}
/// <summary>

View File

@@ -18,11 +18,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// <summary>
/// Creates a new SqlConnection object
/// </summary>
public DbConnection CreateSqlConnection(string connectionString)
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
{
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
}
}
}