mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-02-16 10:58:30 -05:00
Add support for Azure Active Directory connections (#727)
This commit is contained in:
@@ -143,8 +143,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
|
|||||||
|
|
||||||
if (connInfo != null)
|
if (connInfo != null)
|
||||||
{
|
{
|
||||||
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo);
|
var serverConnection = ConnectionService.OpenServerConnection(connInfo);
|
||||||
var serverConnection = new ServerConnection(sqlConnection);
|
|
||||||
var fetcher = new JobFetcher(serverConnection);
|
var fetcher = new JobFetcher(serverConnection);
|
||||||
var filter = new JobActivityFilter();
|
var filter = new JobActivityFilter();
|
||||||
var jobs = fetcher.FetchJobs(filter);
|
var jobs = fetcher.FetchJobs(filter);
|
||||||
@@ -158,7 +157,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
|
|||||||
}
|
}
|
||||||
result.Success = true;
|
result.Success = true;
|
||||||
result.Jobs = agentJobs.ToArray();
|
result.Jobs = agentJobs.ToArray();
|
||||||
sqlConnection.Close();
|
serverConnection.SqlConnectionObject.Close();
|
||||||
}
|
}
|
||||||
await requestContext.SendResult(result);
|
await requestContext.SendResult(result);
|
||||||
}
|
}
|
||||||
@@ -269,8 +268,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
|
|||||||
out connInfo);
|
out connInfo);
|
||||||
if (connInfo != null)
|
if (connInfo != null)
|
||||||
{
|
{
|
||||||
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo);
|
var serverConnection = ConnectionService.OpenServerConnection(connInfo);
|
||||||
var serverConnection = new ServerConnection(sqlConnection);
|
|
||||||
var jobHelper = new JobHelper(serverConnection);
|
var jobHelper = new JobHelper(serverConnection);
|
||||||
jobHelper.JobName = parameters.JobName;
|
jobHelper.JobName = parameters.JobName;
|
||||||
switch(parameters.Action)
|
switch(parameters.Action)
|
||||||
@@ -1163,8 +1161,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent
|
|||||||
|
|
||||||
private Tuple<SqlConnectionInfo, DataTable, ServerConnection> CreateSqlConnection(ConnectionInfo connInfo, String jobId)
|
private Tuple<SqlConnectionInfo, DataTable, ServerConnection> CreateSqlConnection(ConnectionInfo connInfo, String jobId)
|
||||||
{
|
{
|
||||||
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo);
|
var serverConnection = ConnectionService.OpenServerConnection(connInfo);
|
||||||
var serverConnection = new ServerConnection(sqlConnection);
|
|
||||||
var server = new Server(serverConnection);
|
var server = new Server(serverConnection);
|
||||||
var filter = new JobHistoryFilter();
|
var filter = new JobHistoryFilter();
|
||||||
filter.JobID = new Guid(jobId);
|
filter.JobID = new Guid(jobId);
|
||||||
|
|||||||
@@ -48,8 +48,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
ValueType = ConnectionOption.ValueTypeCategory,
|
ValueType = ConnectionOption.ValueTypeCategory,
|
||||||
SpecialValueType = ConnectionOption.SpecialValueAuthType,
|
SpecialValueType = ConnectionOption.SpecialValueAuthType,
|
||||||
CategoryValues = new CategoryValue[]
|
CategoryValues = new CategoryValue[]
|
||||||
{ new CategoryValue {DisplayName = "SQL Login", Name = "SqlLogin" },
|
{ new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" },
|
||||||
new CategoryValue {DisplayName = "Windows Authentication", Name= "Integrated" }
|
new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" },
|
||||||
|
new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" }
|
||||||
},
|
},
|
||||||
IsIdentity = true,
|
IsIdentity = true,
|
||||||
IsRequired = true,
|
IsRequired = true,
|
||||||
|
|||||||
@@ -523,7 +523,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
|
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
|
||||||
|
|
||||||
// create a sql connection instance
|
// create a sql connection instance
|
||||||
connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
|
connection = connectionInfo.Factory.CreateSqlConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken);
|
||||||
connectionInfo.AddConnection(connectionParams.Type, connection);
|
connectionInfo.AddConnection(connectionParams.Type, connection);
|
||||||
|
|
||||||
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
|
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
|
||||||
@@ -909,7 +909,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
|
|
||||||
// Connect to master and query sys.databases
|
// Connect to master and query sys.databases
|
||||||
connectionDetails.DatabaseName = "master";
|
connectionDetails.DatabaseName = "master";
|
||||||
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails));
|
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken);
|
||||||
connection.Open();
|
connection.Open();
|
||||||
|
|
||||||
List<string> results = new List<string>();
|
List<string> results = new List<string>();
|
||||||
@@ -1151,6 +1151,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
break;
|
break;
|
||||||
case "SqlLogin":
|
case "SqlLogin":
|
||||||
break;
|
break;
|
||||||
|
case "AzureMFA":
|
||||||
|
connectionBuilder.UserID = "";
|
||||||
|
connectionBuilder.Password = "";
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
|
throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
|
||||||
}
|
}
|
||||||
@@ -1387,7 +1391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
string connectionString = BuildConnectionString(info.ConnectionDetails);
|
string connectionString = BuildConnectionString(info.ConnectionDetails);
|
||||||
|
|
||||||
// create a sql connection instance
|
// create a sql connection instance
|
||||||
DbConnection connection = info.Factory.CreateSqlConnection(connectionString);
|
DbConnection connection = info.Factory.CreateSqlConnection(connectionString, info.ConnectionDetails.AzureAccountToken);
|
||||||
connection.Open();
|
connection.Open();
|
||||||
info.AddConnection(key, connection);
|
info.AddConnection(key, connection);
|
||||||
}
|
}
|
||||||
@@ -1488,6 +1492,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
/// Note: we need to audit all uses of this method to determine why we're
|
/// Note: we need to audit all uses of this method to determine why we're
|
||||||
/// bypassing normal ConnectionService connection management
|
/// bypassing normal ConnectionService connection management
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
/// <param name="connInfo">The connection info to connect with</param>
|
||||||
|
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
|
||||||
|
/// <returns>A SqlConnection created with the given connection info</returns>
|
||||||
internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null)
|
internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
@@ -1515,6 +1522,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
|
|
||||||
// open a dedicated binding server connection
|
// open a dedicated binding server connection
|
||||||
SqlConnection sqlConn = new SqlConnection(connectionString);
|
SqlConnection sqlConn = new SqlConnection(connectionString);
|
||||||
|
|
||||||
|
// Fill in Azure authentication token if needed
|
||||||
|
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||||
|
{
|
||||||
|
sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||||
|
}
|
||||||
|
|
||||||
sqlConn.Open();
|
sqlConn.Open();
|
||||||
return sqlConn;
|
return sqlConn;
|
||||||
}
|
}
|
||||||
@@ -1529,6 +1543,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Create and open a new ServerConnection from a ConnectionInfo object.
|
||||||
|
/// This calls ConnectionService.OpenSqlConnection and then creates a
|
||||||
|
/// ServerConnection from it.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connInfo">The connection info to connect with</param>
|
||||||
|
/// <param name="featureName">A plaintext string that will be included in the application name for the connection</param>
|
||||||
|
/// <returns>A ServerConnection (wrapping a SqlConnection) created with the given connection info</returns>
|
||||||
|
internal static ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName = null)
|
||||||
|
{
|
||||||
|
var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName);
|
||||||
|
ServerConnection serverConnection;
|
||||||
|
if (connInfo.ConnectionDetails.AzureAccountToken != null)
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(sqlConnection);
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverConnection;
|
||||||
|
}
|
||||||
|
|
||||||
public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false)
|
public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false)
|
||||||
{
|
{
|
||||||
// verify that the connection is open
|
// verify that the connection is open
|
||||||
@@ -1552,4 +1590,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class AzureAccessToken : IRenewableToken
|
||||||
|
{
|
||||||
|
public DateTimeOffset TokenExpiry { get; set; }
|
||||||
|
public string Resource { get; set; }
|
||||||
|
public string Tenant { get; set; }
|
||||||
|
public string UserId { get; set; }
|
||||||
|
|
||||||
|
private string accessToken;
|
||||||
|
|
||||||
|
public AzureAccessToken(string accessToken)
|
||||||
|
{
|
||||||
|
this.accessToken = accessToken;
|
||||||
|
}
|
||||||
|
|
||||||
|
public string GetAccessToken()
|
||||||
|
{
|
||||||
|
return this.accessToken;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -497,6 +497,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public string AzureAccountToken
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
return GetOptionValue<string>("azureAccountToken");
|
||||||
|
}
|
||||||
|
set
|
||||||
|
{
|
||||||
|
SetOptionValue("azureAccountToken", value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public bool IsComparableTo(ConnectionDetails other)
|
public bool IsComparableTo(ConnectionDetails other)
|
||||||
{
|
{
|
||||||
if (other == null)
|
if (other == null)
|
||||||
@@ -506,7 +518,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
|||||||
|
|
||||||
if (ServerName != other.ServerName
|
if (ServerName != other.ServerName
|
||||||
|| AuthenticationType != other.AuthenticationType
|
|| AuthenticationType != other.AuthenticationType
|
||||||
|| UserName != other.UserName)
|
|| UserName != other.UserName
|
||||||
|
|| AzureAccountToken != other.AzureAccountToken)
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
|
|||||||
PacketSize = details.PacketSize,
|
PacketSize = details.PacketSize,
|
||||||
TypeSystemVersion = details.TypeSystemVersion,
|
TypeSystemVersion = details.TypeSystemVersion,
|
||||||
ConnectionString = details.ConnectionString,
|
ConnectionString = details.ConnectionString,
|
||||||
Port = details.Port
|
Port = details.Port,
|
||||||
|
AzureAccountToken = details.AzureAccountToken
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Create a new SQL Connection object
|
/// Create a new SQL Connection object
|
||||||
/// </summary>
|
/// </summary>
|
||||||
DbConnection CreateSqlConnection(string connectionString);
|
DbConnection CreateSqlConnection(string connectionString, string azureAccountToken);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,10 +46,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
/// Opens the connection and sets the lock/command timeout and pooling=false.
|
/// Opens the connection and sets the lock/command timeout and pooling=false.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <returns>The opened connection</returns>
|
/// <returns>The opened connection</returns>
|
||||||
public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry)
|
public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry, string azureAccountToken)
|
||||||
{
|
{
|
||||||
csb.Pooling = false;
|
csb.Pooling = false;
|
||||||
return OpenConnection(csb.ToString(), useRetry);
|
return OpenConnection(csb.ToString(), useRetry, azureAccountToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -57,7 +57,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
/// will assert if pooling!=false.
|
/// will assert if pooling!=false.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <returns>The opened connection</returns>
|
/// <returns>The opened connection</returns>
|
||||||
public static IDbConnection OpenConnection(string connectionString, bool useRetry)
|
public static IDbConnection OpenConnection(string connectionString, bool useRetry, string azureAccountToken)
|
||||||
{
|
{
|
||||||
#if DEBUG
|
#if DEBUG
|
||||||
try
|
try
|
||||||
@@ -88,7 +88,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy();
|
connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy();
|
||||||
}
|
}
|
||||||
|
|
||||||
ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
|
ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
@@ -136,7 +136,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
SqlConnectionStringBuilder csb,
|
SqlConnectionStringBuilder csb,
|
||||||
Action<IDbConnection> usingConnection,
|
Action<IDbConnection> usingConnection,
|
||||||
Predicate<Exception> catchException,
|
Predicate<Exception> catchException,
|
||||||
bool useRetry)
|
bool useRetry,
|
||||||
|
string azureAccountToken)
|
||||||
{
|
{
|
||||||
Validate.IsNotNull(nameof(csb), csb);
|
Validate.IsNotNull(nameof(csb), csb);
|
||||||
Validate.IsNotNull(nameof(usingConnection), usingConnection);
|
Validate.IsNotNull(nameof(usingConnection), usingConnection);
|
||||||
@@ -145,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
{
|
{
|
||||||
// Always disable pooling
|
// Always disable pooling
|
||||||
csb.Pooling = false;
|
csb.Pooling = false;
|
||||||
using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry))
|
using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry, azureAccountToken))
|
||||||
{
|
{
|
||||||
usingConnection(conn);
|
usingConnection(conn);
|
||||||
}
|
}
|
||||||
@@ -228,7 +229,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
string commandText,
|
string commandText,
|
||||||
Action<IDbCommand> initializeCommand,
|
Action<IDbCommand> initializeCommand,
|
||||||
Predicate<Exception> catchException,
|
Predicate<Exception> catchException,
|
||||||
bool useRetry)
|
bool useRetry,
|
||||||
|
string azureAccountToken)
|
||||||
{
|
{
|
||||||
object retObject = null;
|
object retObject = null;
|
||||||
OpenConnection(
|
OpenConnection(
|
||||||
@@ -238,7 +240,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException);
|
retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException);
|
||||||
},
|
},
|
||||||
catchException,
|
catchException,
|
||||||
useRetry);
|
useRetry,
|
||||||
|
azureAccountToken);
|
||||||
|
|
||||||
return retObject;
|
return retObject;
|
||||||
}
|
}
|
||||||
@@ -636,7 +639,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler.
|
/// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder)
|
public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder, string azureAccountToken)
|
||||||
{
|
{
|
||||||
Validate.IsNotNull(nameof(builder), builder);
|
Validate.IsNotNull(nameof(builder), builder);
|
||||||
|
|
||||||
@@ -670,7 +673,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
Logger.Write(TraceEventType.Error, ex.ToString());
|
Logger.Write(TraceEventType.Error, ex.ToString());
|
||||||
return StandardExceptionHandler(ex); // handled
|
return StandardExceptionHandler(ex); // handled
|
||||||
},
|
},
|
||||||
useRetry: true);
|
useRetry: true,
|
||||||
|
azureAccountToken: azureAccountToken);
|
||||||
|
|
||||||
return isDatabaseReadOnly;
|
return isDatabaseReadOnly;
|
||||||
}
|
}
|
||||||
@@ -697,7 +701,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
public string MachineName;
|
public string MachineName;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo)
|
public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo, string azureAccountToken)
|
||||||
{
|
{
|
||||||
serverInfo = null;
|
serverInfo = null;
|
||||||
if (!TryGetConnectionStringBuilder(connectionString, out SqlConnectionStringBuilder builder))
|
if (!TryGetConnectionStringBuilder(connectionString, out SqlConnectionStringBuilder builder))
|
||||||
@@ -705,14 +709,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
serverInfo = GetServerVersion(builder);
|
serverInfo = GetServerVersion(builder, azureAccountToken);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Returns the version of the server. This routine will throw if an exception is encountered.
|
/// Returns the version of the server. This routine will throw if an exception is encountered.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb)
|
public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb, string azureAccountToken)
|
||||||
{
|
{
|
||||||
Validate.IsNotNull(nameof(csb), csb);
|
Validate.IsNotNull(nameof(csb), csb);
|
||||||
ServerInfo serverInfo = null;
|
ServerInfo serverInfo = null;
|
||||||
@@ -724,7 +728,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
serverInfo = GetServerVersion(connection);
|
serverInfo = GetServerVersion(connection);
|
||||||
},
|
},
|
||||||
catchException: null, // Always throw
|
catchException: null, // Always throw
|
||||||
useRetry: true);
|
useRetry: true,
|
||||||
|
azureAccountToken: azureAccountToken);
|
||||||
|
|
||||||
return serverInfo;
|
return serverInfo;
|
||||||
}
|
}
|
||||||
@@ -1057,7 +1062,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
/// Returns true if the authenticating database is master, otherwise false. An example of
|
/// Returns true if the authenticating database is master, otherwise false. An example of
|
||||||
/// false is when the user is a contained user connecting to a contained database.
|
/// false is when the user is a contained user connecting to a contained database.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder)
|
public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder, string azureAccountToken)
|
||||||
{
|
{
|
||||||
bool authIsMaster = true;
|
bool authIsMaster = true;
|
||||||
OpenConnection(
|
OpenConnection(
|
||||||
@@ -1067,7 +1072,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
authIsMaster = IsAuthenticatingDatabaseMaster(connection);
|
authIsMaster = IsAuthenticatingDatabaseMaster(connection);
|
||||||
},
|
},
|
||||||
catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception
|
catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception
|
||||||
useRetry: true);
|
useRetry: true,
|
||||||
|
azureAccountToken: azureAccountToken);
|
||||||
return authIsMaster;
|
return authIsMaster;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
/// <param name="connectionString">The connection string used to open the SQL Azure database.</param>
|
/// <param name="connectionString">The connection string used to open the SQL Azure database.</param>
|
||||||
/// <param name="connectionRetryPolicy">The retry policy defining whether to retry a request if a connection fails to be established.</param>
|
/// <param name="connectionRetryPolicy">The retry policy defining whether to retry a request if a connection fails to be established.</param>
|
||||||
/// <param name="commandRetryPolicy">The retry policy defining whether to retry a request if a command fails to be executed.</param>
|
/// <param name="commandRetryPolicy">The retry policy defining whether to retry a request if a command fails to be executed.</param>
|
||||||
public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy)
|
public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy, string azureAccountToken)
|
||||||
{
|
{
|
||||||
_underlyingConnection = new SqlConnection(connectionString);
|
_underlyingConnection = new SqlConnection(connectionString);
|
||||||
_connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy();
|
_connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy();
|
||||||
@@ -68,6 +68,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection
|
|||||||
_underlyingConnection.StateChange += OnConnectionStateChange;
|
_underlyingConnection.StateChange += OnConnectionStateChange;
|
||||||
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
|
_connectionRetryPolicy.RetryOccurred += RetryConnectionCallback;
|
||||||
_commandRetryPolicy.RetryOccurred += RetryCommandCallback;
|
_commandRetryPolicy.RetryOccurred += RetryCommandCallback;
|
||||||
|
|
||||||
|
if (azureAccountToken != null)
|
||||||
|
{
|
||||||
|
_underlyingConnection.AccessToken = azureAccountToken;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates a new SqlConnection object
|
/// Creates a new SqlConnection object
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public DbConnection CreateSqlConnection(string connectionString)
|
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
|
||||||
{
|
{
|
||||||
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
||||||
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
||||||
return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy);
|
return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -213,8 +213,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation
|
|||||||
|
|
||||||
if (connInfo != null)
|
if (connInfo != null)
|
||||||
{
|
{
|
||||||
SqlConnection connection = ConnectionService.OpenSqlConnection(connInfo, "Restore");
|
Server server = new Server(ConnectionService.OpenServerConnection(connInfo, "Restore"));
|
||||||
Server server = new Server(new ServerConnection(connection));
|
|
||||||
|
|
||||||
RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, targetDatabaseName);
|
RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, targetDatabaseName);
|
||||||
return restoreDataObject;
|
return restoreDataObject;
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ using System.Data.Common;
|
|||||||
using System.Data.SqlClient;
|
using System.Data.SqlClient;
|
||||||
using Microsoft.SqlServer.Management.Common;
|
using Microsoft.SqlServer.Management.Common;
|
||||||
using Microsoft.SqlServer.Management.Smo;
|
using Microsoft.SqlServer.Management.Smo;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
|
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
|
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
|
||||||
using Microsoft.SqlTools.Utility;
|
using Microsoft.SqlTools.Utility;
|
||||||
@@ -56,7 +57,16 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect with SMO and get the metadata for the table
|
// Connect with SMO and get the metadata for the table
|
||||||
Server server = new Server(new ServerConnection(sqlConn));
|
ServerConnection serverConnection;
|
||||||
|
if (sqlConn.AccessToken == null)
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(sqlConn);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken));
|
||||||
|
}
|
||||||
|
Server server = new Server(serverConnection);
|
||||||
Database db = new Database(server, sqlConn.Database);
|
Database db = new Database(server, sqlConn.Database);
|
||||||
|
|
||||||
TableViewTableTypeBase smoResult;
|
TableViewTableTypeBase smoResult;
|
||||||
|
|||||||
@@ -37,9 +37,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// Virtual method used to support mocking and testing
|
/// Virtual method used to support mocking and testing
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public virtual SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName)
|
public virtual ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName)
|
||||||
{
|
{
|
||||||
return ConnectionService.OpenSqlConnection(connInfo, featureName);
|
return ConnectionService.OpenServerConnection(connInfo, featureName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,10 +198,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
|
|||||||
try
|
try
|
||||||
{
|
{
|
||||||
bindingContext.BindingLock.Reset();
|
bindingContext.BindingLock.Reset();
|
||||||
SqlConnection sqlConn = connectionOpener.OpenSqlConnection(connInfo, featureName);
|
|
||||||
|
|
||||||
// populate the binding context to work with the SMO metadata provider
|
// populate the binding context to work with the SMO metadata provider
|
||||||
bindingContext.ServerConnection = new ServerConnection(sqlConn);
|
bindingContext.ServerConnection = connectionOpener.OpenServerConnection(connInfo, featureName);
|
||||||
|
|
||||||
if (this.needsMetadata)
|
if (this.needsMetadata)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -682,7 +682,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
|
|||||||
/// <param name="userName">User name for not trusted connections</param>
|
/// <param name="userName">User name for not trusted connections</param>
|
||||||
/// <param name="password">Password for not trusted connections</param>
|
/// <param name="password">Password for not trusted connections</param>
|
||||||
/// <param name="xmlParameters">XML string with parameters</param>
|
/// <param name="xmlParameters">XML string with parameters</param>
|
||||||
public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters)
|
public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters, string azureAccountToken = null)
|
||||||
{
|
{
|
||||||
this.serverType = serverType;
|
this.serverType = serverType;
|
||||||
this.serverName = serverName;
|
this.serverName = serverName;
|
||||||
@@ -690,7 +690,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
|
|||||||
if (serverType == ServerType.SQL)
|
if (serverType == ServerType.SQL)
|
||||||
{
|
{
|
||||||
// does some extra initialization
|
// does some extra initialization
|
||||||
ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName), true);
|
ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName, azureAccountToken), true);
|
||||||
|
|
||||||
// NOTE: ServerConnection property will constuct the object if needed
|
// NOTE: ServerConnection property will constuct the object if needed
|
||||||
m_server = new Server(ServerConnection);
|
m_server = new Server(ServerConnection);
|
||||||
@@ -1024,7 +1024,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
|
|||||||
bool trusted,
|
bool trusted,
|
||||||
string userName,
|
string userName,
|
||||||
SecureString password,
|
SecureString password,
|
||||||
string databaseName)
|
string databaseName,
|
||||||
|
string azureAccountToken)
|
||||||
{
|
{
|
||||||
SqlConnectionInfoWithConnection tempCI = new SqlConnectionInfoWithConnection(serverName);
|
SqlConnectionInfoWithConnection tempCI = new SqlConnectionInfoWithConnection(serverName);
|
||||||
tempCI.SingleConnection = false;
|
tempCI.SingleConnection = false;
|
||||||
@@ -1040,6 +1041,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
|
|||||||
tempCI.UserName = userName;
|
tempCI.UserName = userName;
|
||||||
tempCI.SecurePassword = password;
|
tempCI.SecurePassword = password;
|
||||||
}
|
}
|
||||||
|
|
||||||
tempCI.DatabaseName = databaseName;
|
tempCI.DatabaseName = databaseName;
|
||||||
|
|
||||||
return tempCI;
|
return tempCI;
|
||||||
@@ -1220,39 +1222,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Management
|
|||||||
containerDoc = CreateDataContainerDocument(connInfo, databaseExists);
|
containerDoc = CreateDataContainerDocument(connInfo, databaseExists);
|
||||||
}
|
}
|
||||||
|
|
||||||
CDataContainer dataContainer;
|
var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer");
|
||||||
|
|
||||||
// add alternate port to server name property if provided
|
var connectionInfoWithConnection = new SqlConnectionInfoWithConnection();
|
||||||
var connectionDetails = connInfo.ConnectionDetails;
|
connectionInfoWithConnection.ServerConnection = serverConnection;
|
||||||
string serverName = !connectionDetails.Port.HasValue
|
CDataContainer dataContainer = new CDataContainer(ServerType.SQL, connectionInfoWithConnection, true);
|
||||||
? connectionDetails.ServerName
|
dataContainer.Init(containerDoc);
|
||||||
: string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value);
|
|
||||||
|
|
||||||
// check if the connection is using SQL Auth or Integrated Auth
|
|
||||||
// TODO: ConnectionQueue try to get an existing connection (ConnectionQueue)
|
|
||||||
if (string.Equals(connectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase))
|
|
||||||
{
|
|
||||||
var passwordSecureString = BuildSecureStringFromPassword(connectionDetails.Password);
|
|
||||||
dataContainer = new CDataContainer(
|
|
||||||
CDataContainer.ServerType.SQL,
|
|
||||||
serverName,
|
|
||||||
false,
|
|
||||||
connectionDetails.UserName,
|
|
||||||
passwordSecureString,
|
|
||||||
connectionDetails.DatabaseName,
|
|
||||||
containerDoc.InnerXml);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
dataContainer = new CDataContainer(
|
|
||||||
CDataContainer.ServerType.SQL,
|
|
||||||
serverName,
|
|
||||||
true,
|
|
||||||
null,
|
|
||||||
null,
|
|
||||||
connectionDetails.DatabaseName,
|
|
||||||
containerDoc.InnerXml);
|
|
||||||
}
|
|
||||||
|
|
||||||
return dataContainer;
|
return dataContainer;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ using System.Data.Common;
|
|||||||
using System.Data.SqlClient;
|
using System.Data.SqlClient;
|
||||||
using Microsoft.SqlServer.Management.Common;
|
using Microsoft.SqlServer.Management.Common;
|
||||||
using Microsoft.SqlServer.Management.Smo;
|
using Microsoft.SqlServer.Management.Smo;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
|
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
|
using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters;
|
||||||
|
|
||||||
@@ -60,7 +61,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect with SMO and get the metadata for the table
|
// Connect with SMO and get the metadata for the table
|
||||||
Server server = new Server(new ServerConnection(sqlConn));
|
ServerConnection serverConnection;
|
||||||
|
if (sqlConn.AccessToken == null)
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(sqlConn);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken));
|
||||||
|
}
|
||||||
|
Server server = new Server(serverConnection);
|
||||||
Database database = server.Databases[sqlConn.Database];
|
Database database = server.Databases[sqlConn.Database];
|
||||||
TableViewTableTypeBase smoResult;
|
TableViewTableTypeBase smoResult;
|
||||||
switch (objectType.ToLowerInvariant())
|
switch (objectType.ToLowerInvariant())
|
||||||
|
|||||||
@@ -21,9 +21,9 @@
|
|||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
|
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
|
||||||
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
||||||
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" />
|
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<Compile Include="**\*.cs" />
|
<Compile Include="**\*.cs" />
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using System.Data.SqlClient;
|
using System.Data.SqlClient;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
|
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
|
||||||
using Microsoft.SqlTools.Utility;
|
using Microsoft.SqlTools.Utility;
|
||||||
using Microsoft.SqlServer.Management.Common;
|
using Microsoft.SqlServer.Management.Common;
|
||||||
@@ -42,10 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
|||||||
ServerConnection = serverConnection;
|
ServerConnection = serverConnection;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ScriptAsScriptingOperation(ScriptingParams parameters) : base(parameters)
|
public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccountToken) : base(parameters)
|
||||||
{
|
{
|
||||||
SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString);
|
SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString);
|
||||||
|
if (azureAccountToken != null)
|
||||||
|
{
|
||||||
|
sqlConnection.AccessToken = azureAccountToken;
|
||||||
|
}
|
||||||
|
|
||||||
ServerConnection = new ServerConnection(sqlConnection);
|
ServerConnection = new ServerConnection(sqlConnection);
|
||||||
|
if (azureAccountToken != null)
|
||||||
|
{
|
||||||
|
ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken);
|
||||||
|
}
|
||||||
|
|
||||||
disconnectAtDispose = true;
|
disconnectAtDispose = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
|||||||
|
|
||||||
private int eventSequenceNumber = 1;
|
private int eventSequenceNumber = 1;
|
||||||
|
|
||||||
public ScriptingScriptOperation(ScriptingParams parameters): base(parameters)
|
private string azureAccessToken;
|
||||||
|
|
||||||
|
public ScriptingScriptOperation(ScriptingParams parameters, string azureAccessToken): base(parameters)
|
||||||
{
|
{
|
||||||
|
this.azureAccessToken = azureAccessToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
public override void Execute()
|
public override void Execute()
|
||||||
@@ -200,7 +203,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
|||||||
selectedObjects.Count(),
|
selectedObjects.Count(),
|
||||||
string.Join(", ", selectedObjects)));
|
string.Join(", ", selectedObjects)));
|
||||||
|
|
||||||
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString);
|
string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString, this.azureAccessToken);
|
||||||
string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog;
|
string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog;
|
||||||
|
|
||||||
foreach (ScriptingObject scriptingObject in selectedObjects)
|
foreach (ScriptingObject scriptingObject in selectedObjects)
|
||||||
|
|||||||
@@ -111,12 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
|||||||
// use the owner uri property to lookup its associated ConnectionInfo
|
// use the owner uri property to lookup its associated ConnectionInfo
|
||||||
// and then build a connection string out of that
|
// and then build a connection string out of that
|
||||||
ConnectionInfo connInfo = null;
|
ConnectionInfo connInfo = null;
|
||||||
|
string accessToken = null;
|
||||||
if (parameters.ConnectionString == null)
|
if (parameters.ConnectionString == null)
|
||||||
{
|
{
|
||||||
ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
|
ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo);
|
||||||
if (connInfo != null)
|
if (connInfo != null)
|
||||||
{
|
{
|
||||||
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
|
parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails);
|
||||||
|
accessToken = connInfo.ConnectionDetails.AzureAccountToken;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -126,11 +128,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
|||||||
|
|
||||||
if (!ShouldCreateScriptAsOperation(parameters))
|
if (!ShouldCreateScriptAsOperation(parameters))
|
||||||
{
|
{
|
||||||
operation = new ScriptingScriptOperation(parameters);
|
operation = new ScriptingScriptOperation(parameters, accessToken);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
operation = new ScriptAsScriptingOperation(parameters);
|
operation = new ScriptAsScriptingOperation(parameters, accessToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait();
|
operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait();
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
using Microsoft.SqlServer.Management.Common;
|
using Microsoft.SqlServer.Management.Common;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
|
using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts;
|
||||||
using Microsoft.SqlTools.Utility;
|
using Microsoft.SqlTools.Utility;
|
||||||
using System;
|
using System;
|
||||||
@@ -71,17 +72,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting
|
|||||||
parameters.OperationId = this.OperationId;
|
parameters.OperationId = this.OperationId;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected string GetServerNameFromLiveInstance(string connectionString)
|
protected string GetServerNameFromLiveInstance(string connectionString, string azureAccessToken)
|
||||||
{
|
{
|
||||||
string serverName = null;
|
string serverName = null;
|
||||||
using (SqlConnection connection = new SqlConnection(connectionString))
|
using (SqlConnection connection = new SqlConnection(connectionString))
|
||||||
{
|
{
|
||||||
|
if (azureAccessToken != null)
|
||||||
|
{
|
||||||
|
connection.AccessToken = azureAccessToken;
|
||||||
|
}
|
||||||
connection.Open();
|
connection.Open();
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
|
ServerConnection serverConnection;
|
||||||
ServerConnection serverConnection = new ServerConnection(connection);
|
if (azureAccessToken == null)
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(connection);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
serverConnection = new ServerConnection(connection, new AzureAccessToken(azureAccessToken));
|
||||||
|
}
|
||||||
serverName = serverConnection.TrueName;
|
serverName = serverConnection.TrueName;
|
||||||
}
|
}
|
||||||
catch (SqlException e)
|
catch (SqlException e)
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
||||||
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
||||||
|
|
||||||
ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy);
|
ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy, null);
|
||||||
return connection;
|
return connection;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,7 +283,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn);
|
logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn);
|
||||||
},
|
},
|
||||||
catchException: null,
|
catchException: null,
|
||||||
useRetry: false);
|
useRetry: false,
|
||||||
|
azureAccountToken: null);
|
||||||
|
|
||||||
Assert.False(string.IsNullOrWhiteSpace(filePath));
|
Assert.False(string.IsNullOrWhiteSpace(filePath));
|
||||||
Assert.False(string.IsNullOrWhiteSpace(logPath));
|
Assert.False(string.IsNullOrWhiteSpace(logPath));
|
||||||
@@ -342,7 +343,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
var connectionBuilder = CreateTestConnectionStringBuilder();
|
var connectionBuilder = CreateTestConnectionStringBuilder();
|
||||||
Assert.NotNull(connectionBuilder);
|
Assert.NotNull(connectionBuilder);
|
||||||
|
|
||||||
bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder);
|
bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder, null);
|
||||||
Assert.False(isReadOnly);
|
Assert.False(isReadOnly);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -352,7 +353,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
[Fact]
|
[Fact]
|
||||||
public void TestIsDatabaseReadonlyWithNullBuilder()
|
public void TestIsDatabaseReadonlyWithNullBuilder()
|
||||||
{
|
{
|
||||||
Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null));
|
Assert.Throws<ArgumentNullException>(() => ReliableConnectionHelper.IsDatabaseReadonly(null, null));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -361,7 +362,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
[Fact]
|
[Fact]
|
||||||
public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed()
|
public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed()
|
||||||
{
|
{
|
||||||
using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true))
|
using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true, azureAccountToken: null))
|
||||||
{
|
{
|
||||||
VerifySessionSettings(conn, true);
|
VerifySessionSettings(conn, true);
|
||||||
VerifySessionSettings(conn, false);
|
VerifySessionSettings(conn, false);
|
||||||
@@ -506,7 +507,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
"SET NOCOUNT ON; SET NOCOUNT OFF;",
|
"SET NOCOUNT ON; SET NOCOUNT OFF;",
|
||||||
ReliableConnectionHelper.SetCommandTimeout,
|
ReliableConnectionHelper.SetCommandTimeout,
|
||||||
null,
|
null,
|
||||||
true
|
true,
|
||||||
|
null
|
||||||
);
|
);
|
||||||
Assert.NotNull(result);
|
Assert.NotNull(result);
|
||||||
}
|
}
|
||||||
@@ -519,7 +521,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
{
|
{
|
||||||
ReliableConnectionHelper.ServerInfo info = null;
|
ReliableConnectionHelper.ServerInfo info = null;
|
||||||
var connBuilder = CreateTestConnectionStringBuilder();
|
var connBuilder = CreateTestConnectionStringBuilder();
|
||||||
Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info));
|
Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info, null));
|
||||||
|
|
||||||
Assert.NotNull(info);
|
Assert.NotNull(info);
|
||||||
Assert.NotNull(info.ServerVersion);
|
Assert.NotNull(info.ServerVersion);
|
||||||
@@ -535,7 +537,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
RunIfWrapper.RunIfWindows(() =>
|
RunIfWrapper.RunIfWindows(() =>
|
||||||
{
|
{
|
||||||
ReliableConnectionHelper.ServerInfo info = null;
|
ReliableConnectionHelper.ServerInfo info = null;
|
||||||
Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info));
|
Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info, null));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -679,12 +681,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
|
|||||||
var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath);
|
var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath);
|
||||||
ConnectionInfo connInfo = result.ConnectionInfo;
|
ConnectionInfo connInfo = result.ConnectionInfo;
|
||||||
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
|
DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
|
||||||
|
connection.Open();
|
||||||
|
|
||||||
Assert.True(connection.State == ConnectionState.Open, "Connection should be open.");
|
Assert.True(connection.State == ConnectionState.Open, "Connection should be open.");
|
||||||
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
|
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
|
||||||
|
|
||||||
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
|
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
|
||||||
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder));
|
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder, null));
|
||||||
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
|
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
|
||||||
|
|
||||||
Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));
|
Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));
|
||||||
|
|||||||
@@ -33,7 +33,7 @@
|
|||||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
|
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
|
||||||
<PackageReference Include="xunit" Version="2.2.0" />
|
<PackageReference Include="xunit" Version="2.2.0" />
|
||||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
||||||
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
|
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
|
||||||
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
|
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
|
||||||
<PackageReference Include="xunit" Version="2.2.0" />
|
<PackageReference Include="xunit" Version="2.2.0" />
|
||||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
||||||
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
|
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
|
||||||
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
|
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.3.0" />
|
||||||
<PackageReference Include="xunit" Version="2.2.0" />
|
<PackageReference Include="xunit" Version="2.2.0" />
|
||||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
||||||
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
|
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
|
||||||
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
});
|
});
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns(mockConnection.Object);
|
.Returns(mockConnection.Object);
|
||||||
|
|
||||||
|
|
||||||
@@ -146,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
.Returns(() => Task.Run(() => {}));
|
.Returns(() => Task.Run(() => {}));
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns(mockConnection.Object)
|
.Returns(mockConnection.Object)
|
||||||
.Returns(mockConnection2.Object);
|
.Returns(mockConnection2.Object);
|
||||||
|
|
||||||
@@ -209,7 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
});
|
});
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns(mockConnection.Object);
|
.Returns(mockConnection.Object);
|
||||||
|
|
||||||
|
|
||||||
@@ -282,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
connectionMock.Setup(c => c.Database).Returns(expectedDbName);
|
connectionMock.Setup(c => c.Database).Returns(expectedDbName);
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns(connectionMock.Object);
|
.Returns(connectionMock.Object);
|
||||||
|
|
||||||
var connectionService = new ConnectionService(mockFactory.Object);
|
var connectionService = new ConnectionService(mockFactory.Object);
|
||||||
@@ -321,8 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
var dummySqlConnection = new TestSqlConnection(null);
|
var dummySqlConnection = new TestSqlConnection(null);
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns((string connString) =>
|
.Returns((string connString, string azureAccountToken) =>
|
||||||
{
|
{
|
||||||
dummySqlConnection.ConnectionString = connString;
|
dummySqlConnection.ConnectionString = connString;
|
||||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||||
@@ -775,7 +775,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
|
|
||||||
// Setup mock connection factory to inject query results
|
// Setup mock connection factory to inject query results
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns(CreateMockDbConnection(new[] {data}));
|
.Returns(CreateMockDbConnection(new[] {data}));
|
||||||
var connectionService = new ConnectionService(mockFactory.Object);
|
var connectionService = new ConnectionService(mockFactory.Object);
|
||||||
|
|
||||||
@@ -1324,8 +1324,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
var connection = new TestSqlConnection(null);
|
var connection = new TestSqlConnection(null);
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns((string connString) =>
|
.Returns((string connString, string azureAccountToken) =>
|
||||||
{
|
{
|
||||||
connection.ConnectionString = connString;
|
connection.ConnectionString = connString;
|
||||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||||
@@ -1374,8 +1374,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
var connection = mockConnection.Object;
|
var connection = mockConnection.Object;
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns((string connString) =>
|
.Returns((string connString, string azureAccountToken) =>
|
||||||
{
|
{
|
||||||
connection.ConnectionString = connString;
|
connection.ConnectionString = connString;
|
||||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||||
@@ -1427,8 +1427,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
var connection = mockConnection.Object;
|
var connection = mockConnection.Object;
|
||||||
|
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns((string connString) =>
|
.Returns((string connString, string azureAccountToken) =>
|
||||||
{
|
{
|
||||||
connection.ConnectionString = connString;
|
connection.ConnectionString = connString;
|
||||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||||
@@ -1484,5 +1484,32 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
|||||||
Assert.Equal(false, details.TrustServerCertificate);
|
Assert.Equal(false, details.TrustServerCertificate);
|
||||||
Assert.Equal(30, details.ConnectTimeout);
|
Assert.Equal(30, details.ConnectTimeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async void ConnectingWithAzureAccountUsesToken()
|
||||||
|
{
|
||||||
|
// Set up mock connection factory
|
||||||
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
|
.Returns(new TestSqlConnection(null));
|
||||||
|
var connectionService = new ConnectionService(mockFactory.Object);
|
||||||
|
|
||||||
|
var details = TestObjects.GetTestConnectionDetails();
|
||||||
|
var azureAccountToken = "testAzureAccountToken";
|
||||||
|
details.AzureAccountToken = azureAccountToken;
|
||||||
|
details.UserName = "";
|
||||||
|
details.Password = "";
|
||||||
|
details.AuthenticationType = "AzureMFA";
|
||||||
|
|
||||||
|
// If I open a connection using connection details that include an account token
|
||||||
|
await connectionService.Connect(new ConnectParams
|
||||||
|
{
|
||||||
|
OwnerUri = "testURI",
|
||||||
|
Connection = details
|
||||||
|
});
|
||||||
|
|
||||||
|
// Then the connection factory got called with details including an account token
|
||||||
|
mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.Is<string>(accountToken => accountToken == azureAccountToken)), Times.Once());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Tests for ReliableConnection code
|
||||||
|
/// </summary>
|
||||||
|
public class ReliableConnectionTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public void ReliableSqlConnectionUsesAzureToken()
|
||||||
|
{
|
||||||
|
ConnectionDetails details = TestObjects.GetTestConnectionDetails();
|
||||||
|
details.UserName = "";
|
||||||
|
details.Password = "";
|
||||||
|
string connectionString = ConnectionService.BuildConnectionString(details);
|
||||||
|
string azureAccountToken = "testAzureAccountToken";
|
||||||
|
RetryPolicy retryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy();
|
||||||
|
|
||||||
|
// If I create a ReliableSqlConnection using an azure account token
|
||||||
|
var reliableConnection = new ReliableSqlConnection(connectionString, retryPolicy, retryPolicy, azureAccountToken);
|
||||||
|
|
||||||
|
// Then the connection's azureAccountToken gets set
|
||||||
|
Assert.Equal(azureAccountToken, reliableConnection.GetUnderlyingConnection().AccessToken);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,8 +14,8 @@
|
|||||||
<PackageReference Include="NUnit" Version="3.10.1" />
|
<PackageReference Include="NUnit" Version="3.10.1" />
|
||||||
<PackageReference Include="xunit" Version="2.2.0" />
|
<PackageReference Include="xunit" Version="2.2.0" />
|
||||||
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
<PackageReference Include="xunit.runner.visualstudio" Version="2.2.0" />
|
||||||
<PackageReference Include="System.Data.SqlClient" Version="4.5.0" />
|
<PackageReference Include="System.Data.SqlClient" Version="4.6.0-preview3-27014-02" />
|
||||||
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.5.0" />
|
<PackageReference Include="System.Text.Encoding.CodePages" Version="4.6.0-preview3-26501-04" />
|
||||||
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
<PackageReference Include="Microsoft.SqlServer.SqlManagementObjects" Version="$(SmoPackageVersion)" />
|
||||||
</ItemGroup>
|
</ItemGroup>
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
|||||||
@@ -408,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
|
|||||||
|
|
||||||
// Stub out the connection to avoid a 30second timeout while attempting to connect.
|
// Stub out the connection to avoid a 30second timeout while attempting to connect.
|
||||||
// The tests don't need any connection context anyhow so this doesn't impact the scenario
|
// The tests don't need any connection context anyhow so this doesn't impact the scenario
|
||||||
mockConnectionOpener.Setup(b => b.OpenSqlConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
|
mockConnectionOpener.Setup(b => b.OpenServerConnection(It.IsAny<ConnectionInfo>(), It.IsAny<string>()))
|
||||||
.Throws<Exception>();
|
.Throws<Exception>();
|
||||||
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
|
connectionServiceMock.Setup(c => c.Connect(It.IsAny<ConnectParams>()))
|
||||||
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));
|
.Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details)));
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
|
|||||||
private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead)
|
private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead)
|
||||||
{
|
{
|
||||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
|
||||||
.Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead));
|
.Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead));
|
||||||
|
|
||||||
return mockFactory.Object;
|
return mockFactory.Object;
|
||||||
@@ -184,7 +184,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution
|
|||||||
// Create a connection info and add the default connection to it
|
// Create a connection info and add the default connection to it
|
||||||
ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead);
|
ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead);
|
||||||
ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails);
|
ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails);
|
||||||
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null);
|
ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null, null);
|
||||||
return ci;
|
return ci;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -427,7 +427,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
|
|||||||
|
|
||||||
private static DbConnection GetConnection(ConnectionInfo info)
|
private static DbConnection GetConnection(ConnectionInfo info)
|
||||||
{
|
{
|
||||||
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
|
return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
|
||||||
}
|
}
|
||||||
|
|
||||||
[SuppressMessage("ReSharper", "UnusedParameter.Local")]
|
[SuppressMessage("ReSharper", "UnusedParameter.Local")]
|
||||||
|
|||||||
@@ -386,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
|
|||||||
private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query)
|
private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query)
|
||||||
{
|
{
|
||||||
var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead);
|
var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead);
|
||||||
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
|
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
|
||||||
var command = connection.CreateCommand();
|
var command = connection.CreateCommand();
|
||||||
command.CommandText = query;
|
command.CommandText = query;
|
||||||
return command.ExecuteReader();
|
return command.ExecuteReader();
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
|
|||||||
private static DbDataReader GetReader(TestResultSet[] dataSet, string query)
|
private static DbDataReader GetReader(TestResultSet[] dataSet, string query)
|
||||||
{
|
{
|
||||||
var info = Common.CreateTestConnectionInfo(dataSet, false, false);
|
var info = Common.CreateTestConnectionInfo(dataSet, false, false);
|
||||||
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails));
|
var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null);
|
||||||
var command = connection.CreateCommand();
|
var command = connection.CreateCommand();
|
||||||
command.CommandText = query;
|
command.CommandText = query;
|
||||||
return command.ExecuteReader();
|
return command.ExecuteReader();
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ using System.Collections.Generic;
|
|||||||
using System.Data;
|
using System.Data;
|
||||||
using System.Data.Common;
|
using System.Data.Common;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
|
using System.Threading;
|
||||||
|
using System.Threading.Tasks;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Connection;
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||||
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
|
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
|
||||||
@@ -273,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
|
|||||||
/// </summary>
|
/// </summary>
|
||||||
public class TestSqlConnectionFactory : ISqlConnectionFactory
|
public class TestSqlConnectionFactory : ISqlConnectionFactory
|
||||||
{
|
{
|
||||||
public DbConnection CreateSqlConnection(string connectionString)
|
public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken)
|
||||||
{
|
{
|
||||||
return new TestSqlConnection(null)
|
return new TestSqlConnection(null)
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user