Add support for Azure Active Directory connections (#727)

This commit is contained in:
Matt Irvine
2018-11-13 11:50:30 -08:00
committed by GitHub
parent 2cb7f682c5
commit 7f28f249de
32 changed files with 291 additions and 121 deletions

View File

@@ -143,8 +143,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
if (connInfo != null) if (connInfo != null)
{ {
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo); var serverConnection = ConnectionService.OpenServerConnection(connInfo);
var serverConnection = new ServerConnection(sqlConnection);
var fetcher = new JobFetcher(serverConnection); var fetcher = new JobFetcher(serverConnection);
var filter = new JobActivityFilter(); var filter = new JobActivityFilter();
var jobs = fetcher.FetchJobs(filter); var jobs = fetcher.FetchJobs(filter);
@@ -158,7 +157,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
} }
result.Success = true; result.Success = true;
result.Jobs = agentJobs.ToArray(); result.Jobs = agentJobs.ToArray();
sqlConnection.Close(); serverConnection.SqlConnectionObject.Close();
} }
await requestContext.SendResult(result); await requestContext.SendResult(result);
} }
@@ -269,8 +268,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
out connInfo); out connInfo);
if (connInfo != null) if (connInfo != null)
{ {
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo); var serverConnection = ConnectionService.OpenServerConnection(connInfo);
var serverConnection = new ServerConnection(sqlConnection);
var jobHelper = new JobHelper(serverConnection); var jobHelper = new JobHelper(serverConnection);
jobHelper.JobName = parameters.JobName; jobHelper.JobName = parameters.JobName;
switch(parameters.Action) switch(parameters.Action)
@@ -1163,8 +1161,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
private Tuple<SqlConnectionInfo, DataTable, ServerConnection> CreateSqlConnection(ConnectionInfo connInfo, String jobId) private Tuple<SqlConnectionInfo, DataTable, ServerConnection> CreateSqlConnection(ConnectionInfo connInfo, String jobId)
{ {
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo); var serverConnection = ConnectionService.OpenServerConnection(connInfo);
var serverConnection = new ServerConnection(sqlConnection);
var server = new Server(serverConnection); var server = new Server(serverConnection);
var filter = new JobHistoryFilter(); var filter = new JobHistoryFilter();
filter.JobID = new Guid(jobId); filter.JobID = new Guid(jobId);

View File

@@ -48,8 +48,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ValueType = ConnectionOption.ValueTypeCategory, ValueType = ConnectionOption.ValueTypeCategory,
SpecialValueType = ConnectionOption.SpecialValueAuthType, SpecialValueType = ConnectionOption.SpecialValueAuthType,
CategoryValues = new CategoryValue[] CategoryValues = new CategoryValue[]
{ new CategoryValue {DisplayName = "SQL Login", Name = "SqlLogin" }, { new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
new CategoryValue {DisplayName = "Windows Authentication", Name= "Integrated" } new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
}, },
IsIdentity = true, IsIdentity = true,
IsRequired = true, IsRequired = true,

View File

@@ -523,7 +523,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
// create a sql connection instance // create a sql connection instance
connection = connectionInfo.Factory.CreateSqlConnection(connectionString); connection = connectionInfo.Factory.CreateSqlConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken);
connectionInfo.AddConnection(connectionParams.Type, connection); connectionInfo.AddConnection(connectionParams.Type, connection);
// Add a cancellation token source so that the connection OpenAsync() can be cancelled // Add a cancellation token source so that the connection OpenAsync() can be cancelled
@@ -909,7 +909,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// Connect to master and query sys.databases // Connect to master and query sys.databases
connectionDetails.DatabaseName = "master"; connectionDetails.DatabaseName = "master";
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails)); var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken);
connection.Open(); connection.Open();
List<string> results = new List<string>(); List<string> results = new List<string>();
@@ -1151,6 +1151,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
break; break;
case "SqlLogin": case "SqlLogin":
break; break;
case "AzureMFA":
connectionBuilder.UserID = "";
connectionBuilder.Password = "";
break;
default: default:
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType)); throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
} }
@@ -1387,7 +1391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
string connectionString = BuildConnectionString(info.ConnectionDetails); string connectionString = BuildConnectionString(info.ConnectionDetails);
// create a sql connection instance // create a sql connection instance
DbConnection connection = info.Factory.CreateSqlConnection(connectionString); DbConnection connection = info.Factory.CreateSqlConnection(connectionString, info.ConnectionDetails.AzureAccountToken);
connection.Open(); connection.Open();
info.AddConnection(key, connection); info.AddConnection(key, connection);
} }
@@ -1488,6 +1492,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// Note: we need to audit all uses of this method to determine why we're /// Note: we need to audit all uses of this method to determine why we're
/// bypassing normal ConnectionService connection management /// bypassing normal ConnectionService connection management
/// </summary> /// </summary>
/// <param name="connInfo">The connection info to connect with</param>
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
/// <returns>A SqlConnection created with the given connection info</returns>
internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null) internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null)
{ {
try try
@@ -1515,6 +1522,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// open a dedicated binding server connection // open a dedicated binding server connection
SqlConnection sqlConn = new SqlConnection(connectionString); SqlConnection sqlConn = new SqlConnection(connectionString);
// Fill in Azure authentication token if needed
if (connInfo.ConnectionDetails.AzureAccountToken != null)
{
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
}
sqlConn.Open(); sqlConn.Open();
return sqlConn; return sqlConn;
} }
@@ -1529,6 +1543,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return null; return null;
} }
/// <summary>
/// Create and open a new ServerConnection from a ConnectionInfo object.
/// This calls ConnectionService.OpenSqlConnection and then creates a
/// ServerConnection from it.
/// </summary>
/// <param name="connInfo">The connection info to connect with</param>
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
/// <returns>A ServerConnection (wrapping a SqlConnection) created with the given connection info</returns>
internal static ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName = null)
{
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
ServerConnection serverConnection;
if (connInfo.ConnectionDetails.AzureAccountToken != null)
{
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
}
else
{
serverConnection = new ServerConnection(sqlConnection);
}
return serverConnection;
}
public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false) public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false)
{ {
// verify that the connection is open // verify that the connection is open
@@ -1552,4 +1590,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
} }
} }
} }
public class AzureAccessToken : IRenewableToken
{
public DateTimeOffset TokenExpiry { get; set; }
public string Resource { get; set; }
public string Tenant { get; set; }
public string UserId { get; set; }
private string accessToken;
public AzureAccessToken(string accessToken)
{
this.accessToken = accessToken;
}
public string GetAccessToken()
{
return this.accessToken;
}
}
} }

View File

@@ -497,6 +497,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
} }
} }
public string AzureAccountToken
{
get
{
return GetOptionValue<string>("azureAccountToken");
}
set
{
SetOptionValue("azureAccountToken", value);
}
}
public bool IsComparableTo(ConnectionDetails other) public bool IsComparableTo(ConnectionDetails other)
{ {
if (other == null) if (other == null)
@@ -506,7 +518,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
if (ServerName != other.ServerName if (ServerName != other.ServerName
|| AuthenticationType != other.AuthenticationType || AuthenticationType != other.AuthenticationType
|| UserName != other.UserName) || UserName != other.UserName
|| AzureAccountToken != other.AzureAccountToken)
{ {
return false; return false;
} }

View File

@@ -44,7 +44,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
PacketSize = details.PacketSize, PacketSize = details.PacketSize,
TypeSystemVersion = details.TypeSystemVersion, TypeSystemVersion = details.TypeSystemVersion,
ConnectionString = details.ConnectionString, ConnectionString = details.ConnectionString,
Port = details.Port Port = details.Port,
AzureAccountToken = details.AzureAccountToken
}; };
} }
} }

View File

@@ -15,6 +15,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// <summary> /// <summary>
/// Create a new SQL Connection object /// Create a new SQL Connection object
/// </summary> /// </summary>
DbConnection CreateSqlConnection(string connectionString); DbConnection CreateSqlConnection(string connectionString, string azureAccountToken);
} }
} }

View File

@@ -46,10 +46,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// Opens the connection and sets the lock/command timeout and pooling=false. /// Opens the connection and sets the lock/command timeout and pooling=false.
/// </summary> /// </summary>
/// <returns>The opened connection</returns> /// <returns>The opened connection</returns>
public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry) public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry, string azureAccountToken)
{ {
csb.Pooling = false; csb.Pooling = false;
return OpenConnection(csb.ToString(), useRetry); return OpenConnection(csb.ToString(), useRetry, azureAccountToken);
} }
/// <summary> /// <summary>
@@ -57,7 +57,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// will assert if pooling!=false. /// will assert if pooling!=false.
/// </summary> /// </summary>
/// <returns>The opened connection</returns> /// <returns>The opened connection</returns>
public static IDbConnection OpenConnection(string connectionString, bool useRetry) public static IDbConnection OpenConnection(string connectionString, bool useRetry, string azureAccountToken)
{ {
#if DEBUG #if DEBUG
try try
@@ -88,7 +88,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy(); connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy();
} }
ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy); ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
try try
{ {
@@ -136,7 +136,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
SqlConnectionStringBuilder csb, SqlConnectionStringBuilder csb,
Action<IDbConnection> usingConnection, Action<IDbConnection> usingConnection,
Predicate<Exception> catchException, Predicate<Exception> catchException,
bool useRetry) bool useRetry,
string azureAccountToken)
{ {
Validate.IsNotNull(nameof(csb), csb); Validate.IsNotNull(nameof(csb), csb);
Validate.IsNotNull(nameof(usingConnection), usingConnection); Validate.IsNotNull(nameof(usingConnection), usingConnection);
@@ -145,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
{ {
// Always disable pooling // Always disable pooling
csb.Pooling = false; csb.Pooling = false;
using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry)) using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry, azureAccountToken))
{ {
usingConnection(conn); usingConnection(conn);
} }
@@ -228,7 +229,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
string commandText, string commandText,
Action<IDbCommand> initializeCommand, Action<IDbCommand> initializeCommand,
Predicate<Exception> catchException, Predicate<Exception> catchException,
bool useRetry) bool useRetry,
string azureAccountToken)
{ {
object retObject = null; object retObject = null;
OpenConnection( OpenConnection(
@@ -238,7 +240,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException); retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException);
}, },
catchException, catchException,
useRetry); useRetry,
azureAccountToken);
return retObject; return retObject;
} }
@@ -636,7 +639,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// <summary> /// <summary>
/// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler. /// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler.
/// </summary> /// </summary>
public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder) public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder, string azureAccountToken)
{ {
Validate.IsNotNull(nameof(builder), builder); Validate.IsNotNull(nameof(builder), builder);
@@ -670,7 +673,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
Logger.Write(TraceEventType.Error, ex.ToString()); Logger.Write(TraceEventType.Error, ex.ToString());
return StandardExceptionHandler(ex); // handled return StandardExceptionHandler(ex); // handled
}, },
useRetry: true); useRetry: true,
azureAccountToken: azureAccountToken);
return isDatabaseReadOnly; return isDatabaseReadOnly;
} }
@@ -697,7 +701,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
public string MachineName; public string MachineName;
} }
public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo) public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo, string azureAccountToken)
{ {
serverInfo = null; serverInfo = null;
if (!TryGetConnectionStringBuilder(connectionString, out SqlConnectionStringBuilder builder)) if (!TryGetConnectionStringBuilder(connectionString, out SqlConnectionStringBuilder builder))
@@ -705,14 +709,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
return false; return false;
} }
serverInfo = GetServerVersion(builder); serverInfo = GetServerVersion(builder, azureAccountToken);
return true; return true;
} }
/// <summary> /// <summary>
/// Returns the version of the server. This routine will throw if an exception is encountered. /// Returns the version of the server. This routine will throw if an exception is encountered.
/// </summary> /// </summary>
public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb) public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb, string azureAccountToken)
{ {
Validate.IsNotNull(nameof(csb), csb); Validate.IsNotNull(nameof(csb), csb);
ServerInfo serverInfo = null; ServerInfo serverInfo = null;
@@ -724,7 +728,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
serverInfo = GetServerVersion(connection); serverInfo = GetServerVersion(connection);
}, },
catchException: null, // Always throw catchException: null, // Always throw
useRetry: true); useRetry: true,
azureAccountToken: azureAccountToken);
return serverInfo; return serverInfo;
} }
@@ -1057,7 +1062,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// Returns true if the authenticating database is master, otherwise false. An example of /// Returns true if the authenticating database is master, otherwise false. An example of
/// false is when the user is a contained user connecting to a contained database. /// false is when the user is a contained user connecting to a contained database.
/// </summary> /// </summary>
public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder) public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder, string azureAccountToken)
{ {
bool authIsMaster = true; bool authIsMaster = true;
OpenConnection( OpenConnection(
@@ -1067,7 +1072,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
authIsMaster = IsAuthenticatingDatabaseMaster(connection); authIsMaster = IsAuthenticatingDatabaseMaster(connection);
}, },
catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception
useRetry: true); useRetry: true,
azureAccountToken: azureAccountToken);
return authIsMaster; return authIsMaster;
} }

View File

@@ -59,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
/// <param name="connectionString">The connection string used to open the SQL Azure database.</param> /// <param name="connectionString">The connection string used to open the SQL Azure database.</param>
/// <param name="connectionRetryPolicy">The retry policy defining whether to retry a request if a connection fails to be established.</param> /// <param name="connectionRetryPolicy">The retry policy defining whether to retry a request if a connection fails to be established.</param>
/// <param name="commandRetryPolicy">The retry policy defining whether to retry a request if a command fails to be executed.</param> /// <param name="commandRetryPolicy">The retry policy defining whether to retry a request if a command fails to be executed.</param>
public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy) public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy, string azureAccountToken)
{ {
_underlyingConnection = new SqlConnection(connectionString); _underlyingConnection = new SqlConnection(connectionString);
_connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy(); _connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy();
@@ -68,6 +68,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
_underlyingConnection.StateChange += OnConnectionStateChange; _underlyingConnection.StateChange += OnConnectionStateChange;
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback; _connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
_commandRetryPolicy.RetryOccurred += RetryCommandCallback; _commandRetryPolicy.RetryOccurred += RetryCommandCallback;
if (azureAccountToken != null)
{
_underlyingConnection.AccessToken = azureAccountToken;
}
} }
/// <summary> /// <summary>

View File

@@ -18,11 +18,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// <summary> /// <summary>
/// Creates a new SqlConnection object /// Creates a new SqlConnection object
/// </summary> /// </summary>
public DbConnection CreateSqlConnection(string connectionString) public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
{ {
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy); return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
} }
} }
} }

View File

@@ -213,8 +213,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
if (connInfo != null) if (connInfo != null)
{ {
SqlConnection connection = ConnectionService.OpenSqlConnection(connInfo, "Restore"); Server server = new Server(ConnectionService.OpenServerConnection(connInfo, "Restore"));
Server server = new Server(new ServerConnection(connection));
RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, targetDatabaseName); RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, targetDatabaseName);
return restoreDataObject; return restoreDataObject;

View File

@@ -9,6 +9,7 @@ using System.Data.Common;
using System.Data.SqlClient; using System.Data.SqlClient;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters; using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
@@ -56,7 +57,16 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
} }
// Connect with SMO and get the metadata for the table // Connect with SMO and get the metadata for the table
Server server = new Server(new ServerConnection(sqlConn)); ServerConnection serverConnection;
if (sqlConn.AccessToken == null)
{
serverConnection = new ServerConnection(sqlConn);
}
else
{
serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken));
}
Server server = new Server(serverConnection);
Database db = new Database(server, sqlConn.Database); Database db = new Database(server, sqlConn.Database);
TableViewTableTypeBase smoResult; TableViewTableTypeBase smoResult;

View File

@@ -37,9 +37,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
/// <summary> /// <summary>
/// Virtual method used to support mocking and testing /// Virtual method used to support mocking and testing
/// </summary> /// </summary>
public virtual SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName) public virtual ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName)
{ {
return ConnectionService.OpenSqlConnection(connInfo, featureName); return ConnectionService.OpenServerConnection(connInfo, featureName);
} }
} }
@@ -198,10 +198,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
try try
{ {
bindingContext.BindingLock.Reset(); bindingContext.BindingLock.Reset();
SqlConnection sqlConn = connectionOpener.OpenSqlConnection(connInfo, featureName);
// populate the binding context to work with the SMO metadata provider // populate the binding context to work with the SMO metadata provider
bindingContext.ServerConnection = new ServerConnection(sqlConn); bindingContext.ServerConnection = connectionOpener.OpenServerConnection(connInfo, featureName);
if (this.needsMetadata) if (this.needsMetadata)
{ {

View File

@@ -682,7 +682,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
/// <param name="userName">User name for not trusted connections</param> /// <param name="userName">User name for not trusted connections</param>
/// <param name="password">Password for not trusted connections</param> /// <param name="password">Password for not trusted connections</param>
/// <param name="xmlParameters">XML string with parameters</param> /// <param name="xmlParameters">XML string with parameters</param>
public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters) public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters, string azureAccountToken = null)
{ {
this.serverType = serverType; this.serverType = serverType;
this.serverName = serverName; this.serverName = serverName;
@@ -690,7 +690,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
if (serverType == ServerType.SQL) if (serverType == ServerType.SQL)
{ {
// does some extra initialization // does some extra initialization
ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName), true); ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName, azureAccountToken), true);
// NOTE: ServerConnection property will constuct the object if needed // NOTE: ServerConnection property will constuct the object if needed
m_server = new Server(ServerConnection); m_server = new Server(ServerConnection);
@@ -1024,7 +1024,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
bool trusted, bool trusted,
string userName, string userName,
SecureString password, SecureString password,
string databaseName) string databaseName,
string azureAccountToken)
{ {
SqlConnectionInfoWithConnection tempCI = new SqlConnectionInfoWithConnection(serverName); SqlConnectionInfoWithConnection tempCI = new SqlConnectionInfoWithConnection(serverName);
tempCI.SingleConnection = false; tempCI.SingleConnection = false;
@@ -1040,6 +1041,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
tempCI.UserName = userName; tempCI.UserName = userName;
tempCI.SecurePassword = password; tempCI.SecurePassword = password;
} }
tempCI.DatabaseName = databaseName; tempCI.DatabaseName = databaseName;
return tempCI; return tempCI;
@@ -1220,39 +1222,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
containerDoc = CreateDataContainerDocument(connInfo, databaseExists); containerDoc = CreateDataContainerDocument(connInfo, databaseExists);
} }
CDataContainer dataContainer; var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
// add alternate port to server name property if provided var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
var connectionDetails = connInfo.ConnectionDetails; connectionInfoWithConnection.ServerConnection = serverConnection;
string serverName = !connectionDetails.Port.HasValue CDataContainer dataContainer = new CDataContainer(ServerType.SQL, connectionInfoWithConnection, true);
? connectionDetails.ServerName dataContainer.Init(containerDoc);
: string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value);
// check if the connection is using SQL Auth or Integrated Auth
// TODO: ConnectionQueue try to get an existing connection (ConnectionQueue)
if (string.Equals(connectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase))
{
var passwordSecureString = BuildSecureStringFromPassword(connectionDetails.Password);
dataContainer = new CDataContainer(
CDataContainer.ServerType.SQL,
serverName,
false,
connectionDetails.UserName,
passwordSecureString,
connectionDetails.DatabaseName,
containerDoc.InnerXml);
}
else
{
dataContainer = new CDataContainer(
CDataContainer.ServerType.SQL,
serverName,
true,
null,
null,
connectionDetails.DatabaseName,
containerDoc.InnerXml);
}
return dataContainer; return dataContainer;
} }

View File

@@ -9,6 +9,7 @@ using System.Data.Common;
using System.Data.SqlClient; using System.Data.SqlClient;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters; using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
@@ -60,7 +61,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata
} }
// Connect with SMO and get the metadata for the table // Connect with SMO and get the metadata for the table
Server server = new Server(new ServerConnection(sqlConn)); ServerConnection serverConnection;
if (sqlConn.AccessToken == null)
{
serverConnection = new ServerConnection(sqlConn);
}
else
{
serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken));
}
Server server = new Server(serverConnection);
Database database = server.Databases[sqlConn.Database]; Database database = server.Databases[sqlConn.Database];
TableViewTableTypeBase smoResult; TableViewTableTypeBase smoResult;
switch (objectType.ToLowerInvariant()) switch (objectType.ToLowerInvariant())

View File

@@ -21,9 +21,9 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" /> <PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" /> <PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" /> <PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<Compile Include="**\*.cs" /> <Compile Include="**\*.cs" />

View File

@@ -6,6 +6,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Data.SqlClient; using System.Data.SqlClient;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts; using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
@@ -42,10 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
ServerConnection = serverConnection; ServerConnection = serverConnection;
} }
public ScriptAsScriptingOperation(ScriptingParams parameters) : base(parameters) public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccountToken) : base(parameters)
{ {
SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString); SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString);
if (azureAccountToken != null)
{
sqlConnection.AccessToken = azureAccountToken;
}
ServerConnection = new ServerConnection(sqlConnection); ServerConnection = new ServerConnection(sqlConnection);
if (azureAccountToken != null)
{
ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken);
}
disconnectAtDispose = true; disconnectAtDispose = true;
} }

View File

@@ -26,8 +26,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
private int eventSequenceNumber = 1; private int eventSequenceNumber = 1;
public ScriptingScriptOperation(ScriptingParams parameters): base(parameters) private string azureAccessToken;
public ScriptingScriptOperation(ScriptingParams parameters, string azureAccessToken): base(parameters)
{ {
this.azureAccessToken = azureAccessToken;
} }
public override void Execute() public override void Execute()
@@ -200,7 +203,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
selectedObjects.Count(), selectedObjects.Count(),
string.Join(", ", selectedObjects))); string.Join(", ", selectedObjects)));
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString); string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString, this.azureAccessToken);
string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog; string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog;
foreach (ScriptingObject scriptingObject in selectedObjects) foreach (ScriptingObject scriptingObject in selectedObjects)

View File

@@ -111,12 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
// use the owner uri property to lookup its associated ConnectionInfo // use the owner uri property to lookup its associated ConnectionInfo
// and then build a connection string out of that // and then build a connection string out of that
ConnectionInfo connInfo = null; ConnectionInfo connInfo = null;
string accessToken = null;
if (parameters.ConnectionString == null) if (parameters.ConnectionString == null)
{ {
ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo); ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
if (connInfo != null) if (connInfo != null)
{ {
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
} }
else else
{ {
@@ -126,11 +128,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
if (!ShouldCreateScriptAsOperation(parameters)) if (!ShouldCreateScriptAsOperation(parameters))
{ {
operation = new ScriptingScriptOperation(parameters); operation = new ScriptingScriptOperation(parameters, accessToken);
} }
else else
{ {
operation = new ScriptAsScriptingOperation(parameters); operation = new ScriptAsScriptingOperation(parameters, accessToken);
} }
operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait(); operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait();

View File

@@ -4,6 +4,7 @@
// //
using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts; using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Utility;
using System; using System;
@@ -71,17 +72,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
parameters.OperationId = this.OperationId; parameters.OperationId = this.OperationId;
} }
protected string GetServerNameFromLiveInstance(string connectionString) protected string GetServerNameFromLiveInstance(string connectionString, string azureAccessToken)
{ {
string serverName = null; string serverName = null;
using (SqlConnection connection = new SqlConnection(connectionString)) using (SqlConnection connection = new SqlConnection(connectionString))
{ {
if (azureAccessToken != null)
{
connection.AccessToken = azureAccessToken;
}
connection.Open(); connection.Open();
try try
{ {
ServerConnection serverConnection;
ServerConnection serverConnection = new ServerConnection(connection); if (azureAccessToken == null)
{
serverConnection = new ServerConnection(connection);
}
else
{
serverConnection = new ServerConnection(connection, new AzureAccessToken(azureAccessToken));
}
serverName = serverConnection.TrueName; serverName = serverConnection.TrueName;
} }
catch (SqlException e) catch (SqlException e)

View File

@@ -258,7 +258,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy); ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy, null);
return connection; return connection;
} }
@@ -283,7 +283,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn); logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn);
}, },
catchException: null, catchException: null,
useRetry: false); useRetry: false,
azureAccountToken: null);
Assert.False(string.IsNullOrWhiteSpace(filePath)); Assert.False(string.IsNullOrWhiteSpace(filePath));
Assert.False(string.IsNullOrWhiteSpace(logPath)); Assert.False(string.IsNullOrWhiteSpace(logPath));
@@ -342,7 +343,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
var connectionBuilder = CreateTestConnectionStringBuilder(); var connectionBuilder = CreateTestConnectionStringBuilder();
Assert.NotNull(connectionBuilder); Assert.NotNull(connectionBuilder);
bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder); bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder, null);
Assert.False(isReadOnly); Assert.False(isReadOnly);
} }
@@ -352,7 +353,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
[Fact] [Fact]
public void TestIsDatabaseReadonlyWithNullBuilder() public void TestIsDatabaseReadonlyWithNullBuilder()
{ {
Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null)); Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null, null));
} }
/// <summary> /// <summary>
@@ -361,7 +362,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
[Fact] [Fact]
public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed() public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed()
{ {
using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true)) using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true, azureAccountToken: null))
{ {
VerifySessionSettings(conn, true); VerifySessionSettings(conn, true);
VerifySessionSettings(conn, false); VerifySessionSettings(conn, false);
@@ -506,7 +507,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
"SET NOCOUNT ON; SET NOCOUNT OFF;", "SET NOCOUNT ON; SET NOCOUNT OFF;",
ReliableConnectionHelper.SetCommandTimeout, ReliableConnectionHelper.SetCommandTimeout,
null, null,
true true,
null
); );
Assert.NotNull(result); Assert.NotNull(result);
} }
@@ -519,7 +521,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
{ {
ReliableConnectionHelper.ServerInfo info = null; ReliableConnectionHelper.ServerInfo info = null;
var connBuilder = CreateTestConnectionStringBuilder(); var connBuilder = CreateTestConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info)); Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info, null));
Assert.NotNull(info); Assert.NotNull(info);
Assert.NotNull(info.ServerVersion); Assert.NotNull(info.ServerVersion);
@@ -535,7 +537,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
RunIfWrapper.RunIfWindows(() => RunIfWrapper.RunIfWindows(() =>
{ {
ReliableConnectionHelper.ServerInfo info = null; ReliableConnectionHelper.ServerInfo info = null;
Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info)); Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info, null));
}); });
} }
@@ -679,12 +681,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath); var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath);
ConnectionInfo connInfo = result.ConnectionInfo; ConnectionInfo connInfo = result.ConnectionInfo;
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default]; DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
connection.Open();
Assert.True(connection.State == ConnectionState.Open, "Connection should be open."); Assert.True(connection.State == ConnectionState.Open, "Connection should be open.");
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection)); Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder)); Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder, null));
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder()); ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
Assert.NotNull(ReliableConnectionHelper.GetServerName(connection)); Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));

View File

@@ -33,7 +33,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" /> <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" /> <PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" /> <PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" /> <PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" /> <PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -12,7 +12,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" /> <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" /> <PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" /> <PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" /> <PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" /> <PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -12,7 +12,7 @@
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" /> <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
<PackageReference Include="xunit" Version="2.2.0" /> <PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" /> <PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" /> <PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" /> <PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -80,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
}); });
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object); .Returns(mockConnection.Object);
@@ -146,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
.Returns(() => Task.Run(() => {})); .Returns(() => Task.Run(() => {}));
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object) .Returns(mockConnection.Object)
.Returns(mockConnection2.Object); .Returns(mockConnection2.Object);
@@ -209,7 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
}); });
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(mockConnection.Object); .Returns(mockConnection.Object);
@@ -282,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
connectionMock.Setup(c => c.Database).Returns(expectedDbName); connectionMock.Setup(c => c.Database).Returns(expectedDbName);
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(connectionMock.Object); .Returns(connectionMock.Object);
var connectionService = new ConnectionService(mockFactory.Object); var connectionService = new ConnectionService(mockFactory.Object);
@@ -321,8 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var dummySqlConnection = new TestSqlConnection(null); var dummySqlConnection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString) => .Returns((string connString, string azureAccountToken) =>
{ {
dummySqlConnection.ConnectionString = connString; dummySqlConnection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -775,7 +775,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// Setup mock connection factory to inject query results // Setup mock connection factory to inject query results
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(CreateMockDbConnection(new[] {data})); .Returns(CreateMockDbConnection(new[] {data}));
var connectionService = new ConnectionService(mockFactory.Object); var connectionService = new ConnectionService(mockFactory.Object);
@@ -1324,8 +1324,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = new TestSqlConnection(null); var connection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString) => .Returns((string connString, string azureAccountToken) =>
{ {
connection.ConnectionString = connString; connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1374,8 +1374,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object; var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString) => .Returns((string connString, string azureAccountToken) =>
{ {
connection.ConnectionString = connString; connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1427,8 +1427,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connection = mockConnection.Object; var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string connString) => .Returns((string connString, string azureAccountToken) =>
{ {
connection.ConnectionString = connString; connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
@@ -1484,5 +1484,32 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.Equal(false, details.TrustServerCertificate); Assert.Equal(false, details.TrustServerCertificate);
Assert.Equal(30, details.ConnectTimeout); Assert.Equal(30, details.ConnectTimeout);
} }
[Fact]
public async void ConnectingWithAzureAccountUsesToken()
{
// Set up mock connection factory
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(new TestSqlConnection(null));
var connectionService = new ConnectionService(mockFactory.Object);
var details = TestObjects.GetTestConnectionDetails();
var azureAccountToken = "testAzureAccountToken";
details.AzureAccountToken = azureAccountToken;
details.UserName = "";
details.Password = "";
details.AuthenticationType = "AzureMFA";
// If I open a connection using connection details that include an account token
await connectionService.Connect(new ConnectParams
{
OwnerUri = "testURI",
Connection = details
});
// Then the connection factory got called with details including an account token
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
}
} }
} }

View File

@@ -0,0 +1,36 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{
/// <summary>
/// Tests for ReliableConnection code
/// </summary>
public class ReliableConnectionTests
{
[Fact]
public void ReliableSqlConnectionUsesAzureToken()
{
ConnectionDetails details = TestObjects.GetTestConnectionDetails();
details.UserName = "";
details.Password = "";
string connectionString = ConnectionService.BuildConnectionString(details);
string azureAccountToken = "testAzureAccountToken";
RetryPolicy retryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
// If I create a ReliableSqlConnection using an azure account token
var reliableConnection = new ReliableSqlConnection(connectionString, retryPolicy, retryPolicy, azureAccountToken);
// Then the connection's azureAccountToken gets set
Assert.Equal(azureAccountToken, reliableConnection.GetUnderlyingConnection().AccessToken);
}
}
}

View File

@@ -14,8 +14,8 @@
<PackageReference Include="NUnit" Version="3.10.1" /> <PackageReference Include="NUnit" Version="3.10.1" />
<PackageReference Include="xunit" Version="2.2.0" /> <PackageReference Include="xunit" Version="2.2.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" /> <PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" /> <PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" /> <PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" /> <PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -408,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
// Stub out the connection to avoid a 30second timeout while attempting to connect. // Stub out the connection to avoid a 30second timeout while attempting to connect.
// The tests don't need any connection context anyhow so this doesn't impact the scenario // The tests don't need any connection context anyhow so this doesn't impact the scenario
mockConnectionOpener.Setup(b => b.OpenSqlConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>())) mockConnectionOpener.Setup(b => b.OpenServerConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
.Throws<Exception>(); .Throws<Exception>();
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>())) connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details))); .Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));

View File

@@ -173,7 +173,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead) private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead)
{ {
var mockFactory = new Mock<ISqlConnectionFactory>(); var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>())) mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead)); .Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead));
return mockFactory.Object; return mockFactory.Object;
@@ -184,7 +184,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
// Create a connection info and add the default connection to it // Create a connection info and add the default connection to it
ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead); ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead);
ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails); ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails);
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null); ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null, null);
return ci; return ci;
} }

View File

@@ -427,7 +427,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbConnection GetConnection(ConnectionInfo info) private static DbConnection GetConnection(ConnectionInfo info)
{ {
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
} }
[SuppressMessage("ReSharper", "UnusedParameter.Local")] [SuppressMessage("ReSharper", "UnusedParameter.Local")]

View File

@@ -386,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query) private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query)
{ {
var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead); var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand(); var command = connection.CreateCommand();
command.CommandText = query; command.CommandText = query;
return command.ExecuteReader(); return command.ExecuteReader();

View File

@@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
private static DbDataReader GetReader(TestResultSet[] dataSet, string query) private static DbDataReader GetReader(TestResultSet[] dataSet, string query)
{ {
var info = Common.CreateTestConnectionInfo(dataSet, false, false); var info = Common.CreateTestConnectionInfo(dataSet, false, false);
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
var command = connection.CreateCommand(); var command = connection.CreateCommand();
command.CommandText = query; command.CommandText = query;
return command.ExecuteReader(); return command.ExecuteReader();

View File

@@ -8,6 +8,8 @@ using System.Collections.Generic;
using System.Data; using System.Data;
using System.Data.Common; using System.Data.Common;
using System.Linq; using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.LanguageServices;
@@ -273,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
/// </summary> /// </summary>
public class TestSqlConnectionFactory : ISqlConnectionFactory public class TestSqlConnectionFactory : ISqlConnectionFactory
{ {
public DbConnection CreateSqlConnection(string connectionString) public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
{ {
return new TestSqlConnection(null) return new TestSqlConnection(null)
{ {