Add support for Azure Active Directory connections (#727)

This commit is contained in:
Matt Irvine
2018-11-13 11:50:30 -08:00
committed by GitHub
parent 2cb7f682c5
commit 7f28f249de
32 changed files with 291 additions and 121 deletions

View File

@@ -143,8 +143,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
if (connInfo != null)
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo);
var serverConnection = new ServerConnection(sqlConnection);
var serverConnection = ConnectionService.OpenServerConnection(connInfo);
var fetcher = new JobFetcher(serverConnection);
var filter = new JobActivityFilter();
var jobs = fetcher.FetchJobs(filter);
@@ -158,7 +157,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
}
result.Success = true;
result.Jobs = agentJobs.ToArray();
sqlConnection.Close();
serverConnection.SqlConnectionObject.Close();
}
await requestContext.SendResult(result);
}
@@ -269,8 +268,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
out connInfo);
if (connInfo != null)
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo);
var serverConnection = new ServerConnection(sqlConnection);
var serverConnection = ConnectionService.OpenServerConnection(connInfo);
var jobHelper = new JobHelper(serverConnection);
jobHelper.JobName = parameters.JobName;
switch(parameters.Action)
@@ -1163,8 +1161,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
private Tuple<SqlConnectionInfo, DataTable, ServerConnection> CreateSqlConnection(ConnectionInfo connInfo, String jobId)
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo);
var serverConnection = new ServerConnection(sqlConnection);
var serverConnection = ConnectionService.OpenServerConnection(connInfo);
var server = new Server(serverConnection);
var filter = new JobHistoryFilter();
filter.JobID = new Guid(jobId);

View File

@@ -48,8 +48,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ValueType = ConnectionOption.ValueTypeCategory,
SpecialValueType = ConnectionOption.SpecialValueAuthType,
CategoryValues = new CategoryValue[]
{ new CategoryValue {DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue {DisplayName = "Windows Authentication", Name= "Integrated" }
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
},
IsIdentity = true,
IsRequired = true,

View File

@@ -523,7 +523,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
// create a sql connection instance
connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
connection = connectionInfo.Factory.CreateSqlConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken);
connectionInfo.AddConnection(connectionParams.Type, connection);
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
@@ -909,7 +909,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// Connect to master and query sys.databases
connectionDetails.DatabaseName = "master";
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails));
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken);
connection.Open();
List<string> results = new List<string>();
@@ -1151,6 +1151,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
break;
case "SqlLogin":
break;
case "AzureMFA":
connectionBuilder.UserID = "";
connectionBuilder.Password = "";
break;
default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
}
@@ -1387,7 +1391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
string connectionString = BuildConnectionString(info.ConnectionDetails);
// create a sql connection instance
DbConnection connection = info.Factory.CreateSqlConnection(connectionString);
DbConnection connection = info.Factory.CreateSqlConnection(connectionString, info.ConnectionDetails.AzureAccountToken);
connection.Open();
info.AddConnection(key, connection);
}
@@ -1488,6 +1492,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// Note: we need to audit all uses of this method to determine why we're
/// bypassing normal ConnectionService connection management
/// </summary>
/// <param name="connInfo">The connection info to connect with</param>
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
/// <returns>A SqlConnection created with the given connection info</returns>
internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null)
{
try
@@ -1515,6 +1522,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// open a dedicated binding server connection
SqlConnection sqlConn = new SqlConnection(connectionString);
// Fill in Azure authentication token if needed
if (connInfo.ConnectionDetails.AzureAccountToken != null)
{
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
sqlConn.Open();
return sqlConn;
}
@@ -1529,6 +1543,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return null;
}
/// <summary>
/// Create and open a new ServerConnection from a ConnectionInfo object.
/// This calls ConnectionService.OpenSqlConnection and then creates a
/// ServerConnection from it.
/// </summary>
/// <param name="connInfo">The connection info to connect with</param>
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
/// <returns>A ServerConnection (wrapping a SqlConnection) created with the given connection info</returns>
internal static ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName = null)
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
ServerConnection serverConnection;
if (connInfo.ConnectionDetails.AzureAccountToken != null)
{
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
}
else
{
serverConnection = new ServerConnection(sqlConnection);
}
return serverConnection;
}
public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false)
{
// verify that the connection is open
@@ -1552,4 +1590,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
}
}
public class AzureAccessToken : IRenewableToken
{
public DateTimeOffset TokenExpiry { get; set; }
public string Resource { get; set; }
public string Tenant { get; set; }
public string UserId { get; set; }
private string accessToken;
public AzureAccessToken(string accessToken)
{
this.accessToken = accessToken;
}
public string GetAccessToken()
{
return this.accessToken;
}
}
}

View File

@@ -497,6 +497,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
}
}
public string AzureAccountToken
{
get
{
return GetOptionValue<string>("azureAccountToken");
}
set
{
SetOptionValue("azureAccountToken", value);
}
}
public bool IsComparableTo(ConnectionDetails other)
{
if (other == null)
@@ -506,7 +518,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
if (ServerName != other.ServerName
|| AuthenticationType != other.AuthenticationType
|| UserName != other.UserName)
|| UserName != other.UserName
|| AzureAccountToken != other.AzureAccountToken)
{
return false;
}

View File

@@ -44,7 +44,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
PacketSize = details.PacketSize,
TypeSystemVersion = details.TypeSystemVersion,
ConnectionString = details.ConnectionString,
Port = details.Port
Port = details.Port,
AzureAccountToken = details.AzureAccountToken
};
}
}

View File

@@ -15,6 +15,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// <summary>
/// Create a new SQL Connection object
/// </summary>
DbConnection CreateSqlConnection(string connectionString);
DbConnection CreateSqlConnection(string connectionString, string azureAccountToken);
}
}

View File

@@ -46,10 +46,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// Opens the connection and sets the lock/command timeout and pooling=false.
/// </summary>
/// <returns>The opened connection</returns>
public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry)
public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry, string azureAccountToken)
{
csb.Pooling = false;
return OpenConnection(csb.ToString(), useRetry);
return OpenConnection(csb.ToString(), useRetry, azureAccountToken);
}
/// <summary>
@@ -57,7 +57,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// will assert if pooling!=false.
/// </summary>
/// <returns>The opened connection</returns>
public static IDbConnection OpenConnection(string connectionString, bool useRetry)
public static IDbConnection OpenConnection(string connectionString, bool useRetry, string azureAccountToken)
{
#if DEBUG
try
@@ -88,7 +88,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy();
}
ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
try
{
@@ -136,7 +136,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
SqlConnectionStringBuilder csb,
Action<IDbConnection> usingConnection,
Predicate<Exception> catchException,
bool useRetry)
bool useRetry,
string azureAccountToken)
{
Validate.IsNotNull(nameof(csb), csb);
Validate.IsNotNull(nameof(usingConnection), usingConnection);
@@ -145,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
{
// Always disable pooling
csb.Pooling = false;
using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry))
using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry, azureAccountToken))
{
usingConnection(conn);
}
@@ -228,7 +229,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
string commandText,
Action<IDbCommand> initializeCommand,
Predicate<Exception> catchException,
bool useRetry)
bool useRetry,
string azureAccountToken)
{
object retObject = null;
OpenConnection(
@@ -238,7 +240,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException);
},
catchException,
useRetry);
useRetry,
azureAccountToken);
return retObject;
}
@@ -636,7 +639,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// <summary>
/// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler.
/// </summary>
public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder)
public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder, string azureAccountToken)
{
Validate.IsNotNull(nameof(builder), builder);
@@ -670,7 +673,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
Logger.Write(TraceEventType.Error, ex.ToString());
return StandardExceptionHandler(ex); // handled
},
useRetry: true);
useRetry: true,
azureAccountToken: azureAccountToken);
return isDatabaseReadOnly;
}
@@ -697,7 +701,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
public string MachineName;
}
public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo)
public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo, string azureAccountToken)
{
serverInfo = null;
if (!TryGetConnectionStringBuilder(connectionString, out SqlConnectionStringBuilder builder))
@@ -705,14 +709,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
return false;
}
serverInfo = GetServerVersion(builder);
serverInfo = GetServerVersion(builder, azureAccountToken);
return true;
}
/// <summary>
/// Returns the version of the server. This routine will throw if an exception is encountered.
/// </summary>
public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb)
public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb, string azureAccountToken)
{
Validate.IsNotNull(nameof(csb), csb);
ServerInfo serverInfo = null;
@@ -724,7 +728,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
serverInfo = GetServerVersion(connection);
},
catchException: null, // Always throw
useRetry: true);
useRetry: true,
azureAccountToken: azureAccountToken);
return serverInfo;
}
@@ -1057,7 +1062,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// Returns true if the authenticating database is master, otherwise false. An example of
/// false is when the user is a contained user connecting to a contained database.
/// </summary>
public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder)
public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder, string azureAccountToken)
{
bool authIsMaster = true;
OpenConnection(
@@ -1067,7 +1072,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
authIsMaster = IsAuthenticatingDatabaseMaster(connection);
},
catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception
useRetry: true);
useRetry: true,
azureAccountToken: azureAccountToken);
return authIsMaster;
}

View File

@@ -59,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// <param name="connectionString">The connection string used to open the SQL Azure database.</param>
/// <param name="connectionRetryPolicy">The retry policy defining whether to retry a request if a connection fails to be established.</param>
/// <param name="commandRetryPolicy">The retry policy defining whether to retry a request if a command fails to be executed.</param>
public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy)
public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy, string azureAccountToken)
{
_underlyingConnection = new SqlConnection(connectionString);
_connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy();
@@ -68,6 +68,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
_underlyingConnection.StateChange += OnConnectionStateChange;
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
_commandRetryPolicy.RetryOccurred += RetryCommandCallback;
if (azureAccountToken != null)
{
_underlyingConnection.AccessToken = azureAccountToken;
}
}
/// <summary>

View File

@@ -18,11 +18,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// <summary>
/// Creates a new SqlConnection object
/// </summary>
public DbConnection CreateSqlConnection(string connectionString)
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
{
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
}
}
}

View File

@@ -213,8 +213,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
if (connInfo != null)
{
SqlConnection connection = ConnectionService.OpenSqlConnection(connInfo, "Restore");
Server server = new Server(new ServerConnection(connection));
Server server = new Server(ConnectionService.OpenServerConnection(connInfo, "Restore"));
RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, targetDatabaseName);
return restoreDataObject;

View File

@@ -9,6 +9,7 @@ using System.Data.Common;
using System.Data.SqlClient;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
using Microsoft.SqlTools.Utility;
@@ -56,7 +57,16 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
}
// Connect with SMO and get the metadata for the table
Server server = new Server(new ServerConnection(sqlConn));
ServerConnection serverConnection;
if (sqlConn.AccessToken == null)
{
serverConnection = new ServerConnection(sqlConn);
}
else
{
serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken));
}
Server server = new Server(serverConnection);
Database db = new Database(server, sqlConn.Database);
TableViewTableTypeBase smoResult;

View File

@@ -37,9 +37,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
/// <summary>
/// Virtual method used to support mocking and testing
/// </summary>
public virtual SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName)
public virtual ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName)
{
return ConnectionService.OpenSqlConnection(connInfo, featureName);
return ConnectionService.OpenServerConnection(connInfo, featureName);
}
}
@@ -198,10 +198,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
try
{
bindingContext.BindingLock.Reset();
SqlConnection sqlConn = connectionOpener.OpenSqlConnection(connInfo, featureName);
// populate the binding context to work with the SMO metadata provider
bindingContext.ServerConnection = new ServerConnection(sqlConn);
bindingContext.ServerConnection = connectionOpener.OpenServerConnection(connInfo, featureName);
if (this.needsMetadata)
{

View File

@@ -682,7 +682,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
/// <param name="userName">User name for not trusted connections</param>
/// <param name="password">Password for not trusted connections</param>
/// <param name="xmlParameters">XML string with parameters</param>
public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters)
public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters, string azureAccountToken = null)
{
this.serverType = serverType;
this.serverName = serverName;
@@ -690,7 +690,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
if (serverType == ServerType.SQL)
{
// does some extra initialization
ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName), true);
ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName, azureAccountToken), true);
// NOTE: ServerConnection property will constuct the object if needed
m_server = new Server(ServerConnection);
@@ -1024,7 +1024,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
bool trusted,
string userName,
SecureString password,
string databaseName)
string databaseName,
string azureAccountToken)
{
SqlConnectionInfoWithConnection tempCI = new SqlConnectionInfoWithConnection(serverName);
tempCI.SingleConnection = false;
@@ -1040,6 +1041,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
tempCI.UserName = userName;
tempCI.SecurePassword = password;
}
tempCI.DatabaseName = databaseName;
return tempCI;
@@ -1220,39 +1222,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
containerDoc = CreateDataContainerDocument(connInfo, databaseExists);
}
CDataContainer dataContainer;
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
// add alternate port to server name property if provided
var connectionDetails = connInfo.ConnectionDetails;
string serverName = !connectionDetails.Port.HasValue
? connectionDetails.ServerName
: string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value);
// check if the connection is using SQL Auth or Integrated Auth
// TODO: ConnectionQueue try to get an existing connection (ConnectionQueue)
if (string.Equals(connectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase))
{
var passwordSecureString = BuildSecureStringFromPassword(connectionDetails.Password);
dataContainer = new CDataContainer(
CDataContainer.ServerType.SQL,
serverName,
false,
connectionDetails.UserName,
passwordSecureString,
connectionDetails.DatabaseName,
containerDoc.InnerXml);
}
else
{
dataContainer = new CDataContainer(
CDataContainer.ServerType.SQL,
serverName,
true,
null,
null,
connectionDetails.DatabaseName,
containerDoc.InnerXml);
}
var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
connectionInfoWithConnection.ServerConnection = serverConnection;
CDataContainer dataContainer = new CDataContainer(ServerType.SQL, connectionInfoWithConnection, true);
dataContainer.Init(containerDoc);
return dataContainer;
}

View File

@@ -9,6 +9,7 @@ using System.Data.Common;
using System.Data.SqlClient;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
@@ -60,7 +61,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata
}
// Connect with SMO and get the metadata for the table
Server server = new Server(new ServerConnection(sqlConn));
ServerConnection serverConnection;
if (sqlConn.AccessToken == null)
{
serverConnection = new ServerConnection(sqlConn);
}
else
{
serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken));
}
Server server = new Server(serverConnection);
Database database = server.Databases[sqlConn.Database];
TableViewTableTypeBase smoResult;
switch (objectType.ToLowerInvariant())

View File

@@ -21,9 +21,9 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
</ItemGroup>
<ItemGroup>
<Compile Include="**\*.cs" />

View File

@@ -6,6 +6,7 @@
using System;
using System.Collections.Generic;
using System.Data.SqlClient;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using Microsoft.SqlServer.Management.Common;
@@ -42,10 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
ServerConnection = serverConnection;
}
public ScriptAsScriptingOperation(ScriptingParams parameters) : base(parameters)
public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccountToken) : base(parameters)
{
SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString);
if (azureAccountToken != null)
{
sqlConnection.AccessToken = azureAccountToken;
}
ServerConnection = new ServerConnection(sqlConnection);
if (azureAccountToken != null)
{
ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken);
}
disconnectAtDispose = true;
}

View File

@@ -26,8 +26,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
private int eventSequenceNumber = 1;
public ScriptingScriptOperation(ScriptingParams parameters): base(parameters)
private string azureAccessToken;
public ScriptingScriptOperation(ScriptingParams parameters, string azureAccessToken): base(parameters)
{
this.azureAccessToken = azureAccessToken;
}
public override void Execute()
@@ -200,7 +203,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
selectedObjects.Count(),
string.Join(", ", selectedObjects)));
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString);
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString, this.azureAccessToken);
string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog;
foreach (ScriptingObject scriptingObject in selectedObjects)

View File

@@ -111,12 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
// use the owner uri property to lookup its associated ConnectionInfo
// and then build a connection string out of that
ConnectionInfo connInfo = null;
string accessToken = null;
if (parameters.ConnectionString == null)
{
ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
if (connInfo != null)
{
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
else
{
@@ -126,11 +128,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
if (!ShouldCreateScriptAsOperation(parameters))
{
operation = new ScriptingScriptOperation(parameters);
operation = new ScriptingScriptOperation(parameters, accessToken);
}
else
{
operation = new ScriptAsScriptingOperation(parameters);
operation = new ScriptAsScriptingOperation(parameters, accessToken);
}
operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait();

View File

@@ -4,6 +4,7 @@
//
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility;
using System;
@@ -71,17 +72,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
parameters.OperationId = this.OperationId;
}
protected string GetServerNameFromLiveInstance(string connectionString)
protected string GetServerNameFromLiveInstance(string connectionString, string azureAccessToken)
{
string serverName = null;
using (SqlConnection connection = new SqlConnection(connectionString))
{
if (azureAccessToken != null)
{
connection.AccessToken = azureAccessToken;
}
connection.Open();
try
{
ServerConnection serverConnection = new ServerConnection(connection);
ServerConnection serverConnection;
if (azureAccessToken == null)
{
serverConnection = new ServerConnection(connection);
}
else
{
serverConnection = new ServerConnection(connection, new AzureAccessToken(azureAccessToken));
}
serverName = serverConnection.TrueName;
}
catch (SqlException e)

View File

@@ -258,7 +258,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy);
ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy, null);
return connection;
}
@@ -283,7 +283,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn);
},
catchException: null,
useRetry: false);
useRetry: false,
azureAccountToken: null);
Assert.False(string.IsNullOrWhiteSpace(filePath));
Assert.False(string.IsNullOrWhiteSpace(logPath));
@@ -342,7 +343,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
var connectionBuilder = CreateTestConnectionStringBuilder();
Assert.NotNull(connectionBuilder);
bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder);
bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder, null);
Assert.False(isReadOnly);
}
@@ -352,7 +353,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
[Fact]
public void TestIsDatabaseReadonlyWithNullBuilder()
{
Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null));
Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null, null));
}
/// <summary>
@@ -361,7 +362,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
[Fact]
public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed()
{
using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true))
using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true, azureAccountToken: null))
{
VerifySessionSettings(conn, true);
VerifySessionSettings(conn, false);
@@ -506,7 +507,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
"SET NOCOUNT ON; SET NOCOUNT OFF;",
ReliableConnectionHelper.SetCommandTimeout,
null,
true
true,
null
);
Assert.NotNull(result);
}
@@ -519,7 +521,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
{
ReliableConnectionHelper.ServerInfo info = null;
var connBuilder = CreateTestConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info));
Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info, null));
Assert.NotNull(info);
Assert.NotNull(info.ServerVersion);
@@ -535,7 +537,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
RunIfWrapper.RunIfWindows(() =>
{
ReliableConnectionHelper.ServerInfo info = null;
Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info));
Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info, null));
});
}
@@ -679,12 +681,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath);
ConnectionInfo connInfo = result.ConnectionInfo;
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
connection.Open();
Assert.True(connection.State == ConnectionState.Open, "Connection should be open.");
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder));
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder, null));
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));

View File

@@ -33,7 +33,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -12,7 +12,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -12,7 +12,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -80,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object);
@@ -146,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
.Returns(() => Task.Run(() => {}));
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object)
.Returns(mockConnection2.Object);
@@ -209,7 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object);
@@ -282,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
connectionMock.Setup(c => c.Database).Returns(expectedDbName);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(connectionMock.Object);
var connectionService = new ConnectionService(mockFactory.Object);
@@ -321,8 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var dummySqlConnection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
dummySqlConnection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -775,7 +775,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// Setup mock connection factory to inject query results
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(CreateMockDbConnection(new[] {data}));
var connectionService = new ConnectionService(mockFactory.Object);
@@ -1324,8 +1324,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1374,8 +1374,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1427,8 +1427,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString, string azureAccountToken) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1484,5 +1484,32 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.Equal(false, details.TrustServerCertificate);
Assert.Equal(30, details.ConnectTimeout);
}
[Fact]
public async void ConnectingWithAzureAccountUsesToken()
{
// Set up mock connection factory
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(new TestSqlConnection(null));
var connectionService = new ConnectionService(mockFactory.Object);
var details = TestObjects.GetTestConnectionDetails();
var azureAccountToken = "testAzureAccountToken";
details.AzureAccountToken = azureAccountToken;
details.UserName = "";
details.Password = "";
details.AuthenticationType = "AzureMFA";
// If I open a connection using connection details that include an account token
await connectionService.Connect(new ConnectParams
{
OwnerUri = "testURI",
Connection = details
});
// Then the connection factory got called with details including an account token
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
}
}
}

View File

@@ -0,0 +1,36 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{
/// <summary>
/// Tests for ReliableConnection code
/// </summary>
public class ReliableConnectionTests
{
[Fact]
public void ReliableSqlConnectionUsesAzureToken()
{
ConnectionDetails details = TestObjects.GetTestConnectionDetails();
details.UserName = "";
details.Password = "";
string connectionString = ConnectionService.BuildConnectionString(details);
string azureAccountToken = "testAzureAccountToken";
RetryPolicy retryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
// If I create a ReliableSqlConnection using an azure account token
var reliableConnection = new ReliableSqlConnection(connectionString, retryPolicy, retryPolicy, azureAccountToken);
// Then the connection's azureAccountToken gets set
Assert.Equal(azureAccountToken, reliableConnection.GetUnderlyingConnection().AccessToken);
}
}
}

View File

@@ -14,8 +14,8 @@
<PackageReference Include="NUnit" Version="3.10.1" />
<PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@@ -408,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
// Stub out the connection to avoid a 30second timeout while attempting to connect.
// The tests don't need any connection context anyhow so this doesn't impact the scenario
mockConnectionOpener.Setup(b => b.OpenSqlConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
mockConnectionOpener.Setup(b => b.OpenServerConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
.Throws<Exception>();
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));

View File

@@ -173,7 +173,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead)
{
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead));
return mockFactory.Object;
@@ -184,7 +184,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
// Create a connection info and add the default connection to it
ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead);
ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails);
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null);
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null, null);
return ci;
}

View File

@@ -427,7 +427,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbConnection GetConnection(ConnectionInfo info)
{
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
}
[SuppressMessage("ReSharper", "UnusedParameter.Local")]

View File

@@ -386,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query)
{
var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand();
command.CommandText = query;
return command.ExecuteReader();

View File

@@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
private static DbDataReader GetReader(TestResultSet[] dataSet, string query)
{
var info = Common.CreateTestConnectionInfo(dataSet, false, false);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand();
command.CommandText = query;
return command.ExecuteReader();

View File

@@ -8,6 +8,8 @@ using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
@@ -273,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
/// </summary>
public class TestSqlConnectionFactory : ISqlConnectionFactory
{
public DbConnection CreateSqlConnection(string connectionString)
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
{
return new TestSqlConnection(null)
{