mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-16 01:25:41 -05:00
Add support for Azure Active Directory connections (#727)
This commit is contained in:
@@ -523,7 +523,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
|
||||
|
||||
// create a sql connection instance
|
||||
connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
|
||||
connection = connectionInfo.Factory.CreateSqlConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken);
|
||||
connectionInfo.AddConnection(connectionParams.Type, connection);
|
||||
|
||||
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
|
||||
@@ -909,7 +909,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
|
||||
// Connect to master and query sys.databases
|
||||
connectionDetails.DatabaseName = "master";
|
||||
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails));
|
||||
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken);
|
||||
connection.Open();
|
||||
|
||||
List<string> results = new List<string>();
|
||||
@@ -1151,6 +1151,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
break;
|
||||
case "SqlLogin":
|
||||
break;
|
||||
case "AzureMFA":
|
||||
connectionBuilder.UserID = "";
|
||||
connectionBuilder.Password = "";
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
|
||||
}
|
||||
@@ -1387,7 +1391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
string connectionString = BuildConnectionString(info.ConnectionDetails);
|
||||
|
||||
// create a sql connection instance
|
||||
DbConnection connection = info.Factory.CreateSqlConnection(connectionString);
|
||||
DbConnection connection = info.Factory.CreateSqlConnection(connectionString, info.ConnectionDetails.AzureAccountToken);
|
||||
connection.Open();
|
||||
info.AddConnection(key, connection);
|
||||
}
|
||||
@@ -1488,6 +1492,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
/// Note: we need to audit all uses of this method to determine why we're
|
||||
/// bypassing normal ConnectionService connection management
|
||||
/// </summary>
|
||||
/// <param name="connInfo">The connection info to connect with</param>
|
||||
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
|
||||
/// <returns>A SqlConnection created with the given connection info</returns>
|
||||
internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null)
|
||||
{
|
||||
try
|
||||
@@ -1515,6 +1522,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
|
||||
// open a dedicated binding server connection
|
||||
SqlConnection sqlConn = new SqlConnection(connectionString);
|
||||
|
||||
// Fill in Azure authentication token if needed
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||
{
|
||||
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||
}
|
||||
|
||||
sqlConn.Open();
|
||||
return sqlConn;
|
||||
}
|
||||
@@ -1529,6 +1543,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
return null;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Create and open a new ServerConnection from a ConnectionInfo object.
|
||||
/// This calls ConnectionService.OpenSqlConnection and then creates a
|
||||
/// ServerConnection from it.
|
||||
/// </summary>
|
||||
/// <param name="connInfo">The connection info to connect with</param>
|
||||
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
|
||||
/// <returns>A ServerConnection (wrapping a SqlConnection) created with the given connection info</returns>
|
||||
internal static ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName = null)
|
||||
{
|
||||
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
|
||||
ServerConnection serverConnection;
|
||||
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||
{
|
||||
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
|
||||
}
|
||||
else
|
||||
{
|
||||
serverConnection = new ServerConnection(sqlConnection);
|
||||
}
|
||||
|
||||
return serverConnection;
|
||||
}
|
||||
|
||||
public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false)
|
||||
{
|
||||
// verify that the connection is open
|
||||
@@ -1552,4 +1590,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class AzureAccessToken : IRenewableToken
|
||||
{
|
||||
public DateTimeOffset TokenExpiry { get; set; }
|
||||
public string Resource { get; set; }
|
||||
public string Tenant { get; set; }
|
||||
public string UserId { get; set; }
|
||||
|
||||
private string accessToken;
|
||||
|
||||
public AzureAccessToken(string accessToken)
|
||||
{
|
||||
this.accessToken = accessToken;
|
||||
}
|
||||
|
||||
public string GetAccessToken()
|
||||
{
|
||||
return this.accessToken;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user