diff --git a/src/Microsoft.SqlTools.ServiceLayer/Agent/AgentService.cs b/src/Microsoft.SqlTools.ServiceLayer/Agent/AgentService.cs index a61e70d7..4c24ea7e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Agent/AgentService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Agent/AgentService.cs @@ -143,8 +143,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent if (connInfo != null) { - var sqlConnection = ConnectionService.OpenSqlConnection(connInfo); - var serverConnection = new ServerConnection(sqlConnection); + var serverConnection = ConnectionService.OpenServerConnection(connInfo); var fetcher = new JobFetcher(serverConnection); var filter = new JobActivityFilter(); var jobs = fetcher.FetchJobs(filter); @@ -158,7 +157,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent } result.Success = true; result.Jobs = agentJobs.ToArray(); - sqlConnection.Close(); + serverConnection.SqlConnectionObject.Close(); } await requestContext.SendResult(result); } @@ -269,8 +268,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent out connInfo); if (connInfo != null) { - var sqlConnection = ConnectionService.OpenSqlConnection(connInfo); - var serverConnection = new ServerConnection(sqlConnection); + var serverConnection = ConnectionService.OpenServerConnection(connInfo); var jobHelper = new JobHelper(serverConnection); jobHelper.JobName = parameters.JobName; switch(parameters.Action) @@ -1163,8 +1161,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Agent private Tuple CreateSqlConnection(ConnectionInfo connInfo, String jobId) { - var sqlConnection = ConnectionService.OpenSqlConnection(connInfo); - var serverConnection = new ServerConnection(sqlConnection); + var serverConnection = ConnectionService.OpenServerConnection(connInfo); var server = new Server(serverConnection); var filter = new JobHistoryFilter(); filter.JobID = new Guid(jobId); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs index 07970742..523809e2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionProviderOptionsHelper.cs @@ -48,8 +48,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection ValueType = ConnectionOption.ValueTypeCategory, SpecialValueType = ConnectionOption.SpecialValueAuthType, CategoryValues = new CategoryValue[] - { new CategoryValue {DisplayName = "SQL Login", Name = "SqlLogin" }, - new CategoryValue {DisplayName = "Windows Authentication", Name= "Integrated" } + { new CategoryValue { DisplayName = "SQL Login", Name = "SqlLogin" }, + new CategoryValue { DisplayName = "Windows Authentication", Name = "Integrated" }, + new CategoryValue { DisplayName = "Azure Active Directory - Universal with MFA support", Name = "AzureMFA" } }, IsIdentity = true, IsRequired = true, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index f4115a73..480eaaed 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -523,7 +523,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); // create a sql connection instance - connection = connectionInfo.Factory.CreateSqlConnection(connectionString); + connection = connectionInfo.Factory.CreateSqlConnection(connectionString, connectionInfo.ConnectionDetails.AzureAccountToken); connectionInfo.AddConnection(connectionParams.Type, connection); // Add a cancellation token source so that the connection OpenAsync() can be cancelled @@ -909,7 +909,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Connect to master and query sys.databases connectionDetails.DatabaseName = "master"; - var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails)); + var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken); connection.Open(); List results = new List(); @@ -1151,6 +1151,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection break; case "SqlLogin": break; + case "AzureMFA": + connectionBuilder.UserID = ""; + connectionBuilder.Password = ""; + break; default: throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType)); } @@ -1387,7 +1391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection string connectionString = BuildConnectionString(info.ConnectionDetails); // create a sql connection instance - DbConnection connection = info.Factory.CreateSqlConnection(connectionString); + DbConnection connection = info.Factory.CreateSqlConnection(connectionString, info.ConnectionDetails.AzureAccountToken); connection.Open(); info.AddConnection(key, connection); } @@ -1488,6 +1492,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// Note: we need to audit all uses of this method to determine why we're /// bypassing normal ConnectionService connection management /// + /// The connection info to connect with + /// A plaintext string that will be included in the application name for the connection + /// A SqlConnection created with the given connection info internal static SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName = null) { try @@ -1515,6 +1522,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // open a dedicated binding server connection SqlConnection sqlConn = new SqlConnection(connectionString); + + // Fill in Azure authentication token if needed + if (connInfo.ConnectionDetails.AzureAccountToken != null) + { + sqlConn.AccessToken = connInfo.ConnectionDetails.AzureAccountToken; + } + sqlConn.Open(); return sqlConn; } @@ -1529,6 +1543,30 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return null; } + /// + /// Create and open a new ServerConnection from a ConnectionInfo object. + /// This calls ConnectionService.OpenSqlConnection and then creates a + /// ServerConnection from it. + /// + /// The connection info to connect with + /// A plaintext string that will be included in the application name for the connection + /// A ServerConnection (wrapping a SqlConnection) created with the given connection info + internal static ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName = null) + { + var sqlConnection = ConnectionService.OpenSqlConnection(connInfo, featureName); + ServerConnection serverConnection; + if (connInfo.ConnectionDetails.AzureAccountToken != null) + { + serverConnection = new ServerConnection(sqlConnection, new AzureAccessToken(connInfo.ConnectionDetails.AzureAccountToken)); + } + else + { + serverConnection = new ServerConnection(sqlConnection); + } + + return serverConnection; + } + public static void EnsureConnectionIsOpen(DbConnection conn, bool forceReopen = false) { // verify that the connection is open @@ -1552,4 +1590,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } } + + public class AzureAccessToken : IRenewableToken + { + public DateTimeOffset TokenExpiry { get; set; } + public string Resource { get; set; } + public string Tenant { get; set; } + public string UserId { get; set; } + + private string accessToken; + + public AzureAccessToken(string accessToken) + { + this.accessToken = accessToken; + } + + public string GetAccessToken() + { + return this.accessToken; + } + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs index c7347021..62fe86b3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -497,6 +497,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts } } + public string AzureAccountToken + { + get + { + return GetOptionValue("azureAccountToken"); + } + set + { + SetOptionValue("azureAccountToken", value); + } + } + public bool IsComparableTo(ConnectionDetails other) { if (other == null) @@ -506,7 +518,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts if (ServerName != other.ServerName || AuthenticationType != other.AuthenticationType - || UserName != other.UserName) + || UserName != other.UserName + || AzureAccountToken != other.AzureAccountToken) { return false; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs index 36c2dde4..aeee8760 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetailsExtensions.cs @@ -44,7 +44,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts PacketSize = details.PacketSize, TypeSystemVersion = details.TypeSystemVersion, ConnectionString = details.ConnectionString, - Port = details.Port + Port = details.Port, + AzureAccountToken = details.AzureAccountToken }; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ISqlConnectionFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ISqlConnectionFactory.cs index ed0cc01b..a0dada33 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ISqlConnectionFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ISqlConnectionFactory.cs @@ -15,6 +15,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// /// Create a new SQL Connection object /// - DbConnection CreateSqlConnection(string connectionString); + DbConnection CreateSqlConnection(string connectionString, string azureAccountToken); } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs index 297cf28c..77c055e5 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableConnectionHelper.cs @@ -46,10 +46,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// Opens the connection and sets the lock/command timeout and pooling=false. /// /// The opened connection - public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry) + public static IDbConnection OpenConnection(SqlConnectionStringBuilder csb, bool useRetry, string azureAccountToken) { csb.Pooling = false; - return OpenConnection(csb.ToString(), useRetry); + return OpenConnection(csb.ToString(), useRetry, azureAccountToken); } /// @@ -57,7 +57,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// will assert if pooling!=false. /// /// The opened connection - public static IDbConnection OpenConnection(string connectionString, bool useRetry) + public static IDbConnection OpenConnection(string connectionString, bool useRetry, string azureAccountToken) { #if DEBUG try @@ -88,7 +88,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection connectionRetryPolicy = RetryPolicyFactory.CreateNoRetryPolicy(); } - ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy); + ReliableSqlConnection connection = new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken); try { @@ -136,7 +136,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection SqlConnectionStringBuilder csb, Action usingConnection, Predicate catchException, - bool useRetry) + bool useRetry, + string azureAccountToken) { Validate.IsNotNull(nameof(csb), csb); Validate.IsNotNull(nameof(usingConnection), usingConnection); @@ -145,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection { // Always disable pooling csb.Pooling = false; - using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry)) + using (IDbConnection conn = OpenConnection(csb.ConnectionString, useRetry, azureAccountToken)) { usingConnection(conn); } @@ -228,7 +229,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection string commandText, Action initializeCommand, Predicate catchException, - bool useRetry) + bool useRetry, + string azureAccountToken) { object retObject = null; OpenConnection( @@ -238,7 +240,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection retObject = ExecuteNonQuery(connection, commandText, initializeCommand, catchException); }, catchException, - useRetry); + useRetry, + azureAccountToken); return retObject; } @@ -636,7 +639,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// /// Returns true if the database is readonly. This routine will swallow the exceptions you might expect from SQL using StandardExceptionHandler. /// - public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder) + public static bool IsDatabaseReadonly(SqlConnectionStringBuilder builder, string azureAccountToken) { Validate.IsNotNull(nameof(builder), builder); @@ -670,7 +673,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection Logger.Write(TraceEventType.Error, ex.ToString()); return StandardExceptionHandler(ex); // handled }, - useRetry: true); + useRetry: true, + azureAccountToken: azureAccountToken); return isDatabaseReadOnly; } @@ -697,7 +701,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection public string MachineName; } - public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo) + public static bool TryGetServerVersion(string connectionString, out ServerInfo serverInfo, string azureAccountToken) { serverInfo = null; if (!TryGetConnectionStringBuilder(connectionString, out SqlConnectionStringBuilder builder)) @@ -705,14 +709,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection return false; } - serverInfo = GetServerVersion(builder); + serverInfo = GetServerVersion(builder, azureAccountToken); return true; } /// /// Returns the version of the server. This routine will throw if an exception is encountered. /// - public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb) + public static ServerInfo GetServerVersion(SqlConnectionStringBuilder csb, string azureAccountToken) { Validate.IsNotNull(nameof(csb), csb); ServerInfo serverInfo = null; @@ -724,7 +728,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection serverInfo = GetServerVersion(connection); }, catchException: null, // Always throw - useRetry: true); + useRetry: true, + azureAccountToken: azureAccountToken); return serverInfo; } @@ -1057,7 +1062,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// Returns true if the authenticating database is master, otherwise false. An example of /// false is when the user is a contained user connecting to a contained database. /// - public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder) + public static bool IsAuthenticatingDatabaseMaster(SqlConnectionStringBuilder builder, string azureAccountToken) { bool authIsMaster = true; OpenConnection( @@ -1067,7 +1072,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection authIsMaster = IsAuthenticatingDatabaseMaster(connection); }, catchException: StandardExceptionHandler, // Don't throw unless it's an unexpected exception - useRetry: true); + useRetry: true, + azureAccountToken: azureAccountToken); return authIsMaster; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs index f498338e..3d9ebbf4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ReliableConnection/ReliableSqlConnection.cs @@ -59,7 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection /// The connection string used to open the SQL Azure database. /// The retry policy defining whether to retry a request if a connection fails to be established. /// The retry policy defining whether to retry a request if a command fails to be executed. - public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy) + public ReliableSqlConnection(string connectionString, RetryPolicy connectionRetryPolicy, RetryPolicy commandRetryPolicy, string azureAccountToken) { _underlyingConnection = new SqlConnection(connectionString); _connectionRetryPolicy = connectionRetryPolicy ?? RetryPolicyFactory.CreateNoRetryPolicy(); @@ -68,6 +68,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection _underlyingConnection.StateChange += OnConnectionStateChange; _connectionRetryPolicy.RetryOccurred += RetryConnectionCallback; _commandRetryPolicy.RetryOccurred += RetryCommandCallback; + + if (azureAccountToken != null) + { + _underlyingConnection.AccessToken = azureAccountToken; + } } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/SqlConnectionFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/SqlConnectionFactory.cs index 4cafb290..f890a72a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/SqlConnectionFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/SqlConnectionFactory.cs @@ -18,11 +18,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// /// Creates a new SqlConnection object /// - public DbConnection CreateSqlConnection(string connectionString) + public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken) { RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); - return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy); + return new ReliableSqlConnection(connectionString, connectionRetryPolicy, commandRetryPolicy, azureAccountToken); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs index c99ac928..6646f203 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/DisasterRecovery/RestoreOperation/RestoreDatabaseHelper.cs @@ -213,8 +213,7 @@ namespace Microsoft.SqlTools.ServiceLayer.DisasterRecovery.RestoreOperation if (connInfo != null) { - SqlConnection connection = ConnectionService.OpenSqlConnection(connInfo, "Restore"); - Server server = new Server(new ServerConnection(connection)); + Server server = new Server(ConnectionService.OpenServerConnection(connInfo, "Restore")); RestoreDatabaseTaskDataObject restoreDataObject = new RestoreDatabaseTaskDataObject(server, targetDatabaseName); return restoreDataObject; diff --git a/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs index cf0b67ac..319f1fcd 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/EditData/SmoEditMetadataFactory.cs @@ -9,6 +9,7 @@ using System.Data.Common; using System.Data.SqlClient; using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters; using Microsoft.SqlTools.Utility; @@ -56,7 +57,16 @@ namespace Microsoft.SqlTools.ServiceLayer.EditData } // Connect with SMO and get the metadata for the table - Server server = new Server(new ServerConnection(sqlConn)); + ServerConnection serverConnection; + if (sqlConn.AccessToken == null) + { + serverConnection = new ServerConnection(sqlConn); + } + else + { + serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken)); + } + Server server = new Server(serverConnection); Database db = new Database(server, sqlConn.Database); TableViewTableTypeBase smoResult; diff --git a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs index 85b9d33c..aac80371 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/LanguageServices/ConnectedBindingQueue.cs @@ -37,9 +37,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices /// /// Virtual method used to support mocking and testing /// - public virtual SqlConnection OpenSqlConnection(ConnectionInfo connInfo, string featureName) + public virtual ServerConnection OpenServerConnection(ConnectionInfo connInfo, string featureName) { - return ConnectionService.OpenSqlConnection(connInfo, featureName); + return ConnectionService.OpenServerConnection(connInfo, featureName); } } @@ -198,10 +198,9 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices try { bindingContext.BindingLock.Reset(); - SqlConnection sqlConn = connectionOpener.OpenSqlConnection(connInfo, featureName); // populate the binding context to work with the SMO metadata provider - bindingContext.ServerConnection = new ServerConnection(sqlConn); + bindingContext.ServerConnection = connectionOpener.OpenServerConnection(connInfo, featureName); if (this.needsMetadata) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/Management/Common/DataContainer.cs b/src/Microsoft.SqlTools.ServiceLayer/Management/Common/DataContainer.cs index a7a6ce28..9efa11f7 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Management/Common/DataContainer.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Management/Common/DataContainer.cs @@ -682,7 +682,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management /// User name for not trusted connections /// Password for not trusted connections /// XML string with parameters - public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters) + public CDataContainer(ServerType serverType, string serverName, bool trusted, string userName, SecureString password, string databaseName, string xmlParameters, string azureAccountToken = null) { this.serverType = serverType; this.serverName = serverName; @@ -690,7 +690,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management if (serverType == ServerType.SQL) { // does some extra initialization - ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName), true); + ApplyConnectionInfo(GetTempSqlConnectionInfoWithConnection(serverName, trusted, userName, password, databaseName, azureAccountToken), true); // NOTE: ServerConnection property will constuct the object if needed m_server = new Server(ServerConnection); @@ -1024,7 +1024,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Management bool trusted, string userName, SecureString password, - string databaseName) + string databaseName, + string azureAccountToken) { SqlConnectionInfoWithConnection tempCI = new SqlConnectionInfoWithConnection(serverName); tempCI.SingleConnection = false; @@ -1040,6 +1041,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Management tempCI.UserName = userName; tempCI.SecurePassword = password; } + tempCI.DatabaseName = databaseName; return tempCI; @@ -1220,39 +1222,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Management containerDoc = CreateDataContainerDocument(connInfo, databaseExists); } - CDataContainer dataContainer; + var serverConnection = ConnectionService.OpenServerConnection(connInfo, "DataContainer"); - // add alternate port to server name property if provided - var connectionDetails = connInfo.ConnectionDetails; - string serverName = !connectionDetails.Port.HasValue - ? connectionDetails.ServerName - : string.Format("{0},{1}", connectionDetails.ServerName, connectionDetails.Port.Value); - - // check if the connection is using SQL Auth or Integrated Auth - // TODO: ConnectionQueue try to get an existing connection (ConnectionQueue) - if (string.Equals(connectionDetails.AuthenticationType, "SqlLogin", StringComparison.OrdinalIgnoreCase)) - { - var passwordSecureString = BuildSecureStringFromPassword(connectionDetails.Password); - dataContainer = new CDataContainer( - CDataContainer.ServerType.SQL, - serverName, - false, - connectionDetails.UserName, - passwordSecureString, - connectionDetails.DatabaseName, - containerDoc.InnerXml); - } - else - { - dataContainer = new CDataContainer( - CDataContainer.ServerType.SQL, - serverName, - true, - null, - null, - connectionDetails.DatabaseName, - containerDoc.InnerXml); - } + var connectionInfoWithConnection = new SqlConnectionInfoWithConnection(); + connectionInfoWithConnection.ServerConnection = serverConnection; + CDataContainer dataContainer = new CDataContainer(ServerType.SQL, connectionInfoWithConnection, true); + dataContainer.Init(containerDoc); return dataContainer; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoMetadataFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoMetadataFactory.cs index a3d9c345..2793ecb3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoMetadataFactory.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoMetadataFactory.cs @@ -9,6 +9,7 @@ using System.Data.Common; using System.Data.SqlClient; using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Utility.SqlScriptFormatters; @@ -60,7 +61,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata } // Connect with SMO and get the metadata for the table - Server server = new Server(new ServerConnection(sqlConn)); + ServerConnection serverConnection; + if (sqlConn.AccessToken == null) + { + serverConnection = new ServerConnection(sqlConn); + } + else + { + serverConnection = new ServerConnection(sqlConn, new AzureAccessToken(sqlConn.AccessToken)); + } + Server server = new Server(serverConnection); Database database = server.Databases[sqlConn.Database]; TableViewTableTypeBase smoResult; switch (objectType.ToLowerInvariant()) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj index 8cc9416d..4d8484cd 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj +++ b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj @@ -21,9 +21,9 @@ - + - + diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs index d2b5318a..2efbe8fe 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptAsScriptingOperation.cs @@ -6,6 +6,7 @@ using System; using System.Collections.Generic; using System.Data.SqlClient; +using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts; using Microsoft.SqlTools.Utility; using Microsoft.SqlServer.Management.Common; @@ -42,10 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting ServerConnection = serverConnection; } - public ScriptAsScriptingOperation(ScriptingParams parameters) : base(parameters) + public ScriptAsScriptingOperation(ScriptingParams parameters, string azureAccountToken) : base(parameters) { SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString); + if (azureAccountToken != null) + { + sqlConnection.AccessToken = azureAccountToken; + } + ServerConnection = new ServerConnection(sqlConnection); + if (azureAccountToken != null) + { + ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken); + } + disconnectAtDispose = true; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingScriptOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingScriptOperation.cs index 87b8dca6..d7efdac8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingScriptOperation.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingScriptOperation.cs @@ -26,8 +26,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting private int eventSequenceNumber = 1; - public ScriptingScriptOperation(ScriptingParams parameters): base(parameters) + private string azureAccessToken; + + public ScriptingScriptOperation(ScriptingParams parameters, string azureAccessToken): base(parameters) { + this.azureAccessToken = azureAccessToken; } public override void Execute() @@ -200,7 +203,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting selectedObjects.Count(), string.Join(", ", selectedObjects))); - string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString); + string server = GetServerNameFromLiveInstance(this.Parameters.ConnectionString, this.azureAccessToken); string database = new SqlConnectionStringBuilder(this.Parameters.ConnectionString).InitialCatalog; foreach (ScriptingObject scriptingObject in selectedObjects) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs index 2d8b51da..04c896f8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/ScriptingService.cs @@ -111,12 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting // use the owner uri property to lookup its associated ConnectionInfo // and then build a connection string out of that ConnectionInfo connInfo = null; + string accessToken = null; if (parameters.ConnectionString == null) { ScriptingService.ConnectionServiceInstance.TryFindConnection(parameters.OwnerUri, out connInfo); if (connInfo != null) { parameters.ConnectionString = ConnectionService.BuildConnectionString(connInfo.ConnectionDetails); + accessToken = connInfo.ConnectionDetails.AzureAccountToken; } else { @@ -126,11 +128,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting if (!ShouldCreateScriptAsOperation(parameters)) { - operation = new ScriptingScriptOperation(parameters); + operation = new ScriptingScriptOperation(parameters, accessToken); } else { - operation = new ScriptAsScriptingOperation(parameters); + operation = new ScriptAsScriptingOperation(parameters, accessToken); } operation.PlanNotification += (sender, e) => requestContext.SendEvent(ScriptingPlanNotificationEvent.Type, e).Wait(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/Scripting/SmoScriptingOperation.cs b/src/Microsoft.SqlTools.ServiceLayer/Scripting/SmoScriptingOperation.cs index f06f2e26..fee5395d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Scripting/SmoScriptingOperation.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Scripting/SmoScriptingOperation.cs @@ -4,6 +4,7 @@ // using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Scripting.Contracts; using Microsoft.SqlTools.Utility; using System; @@ -71,17 +72,28 @@ namespace Microsoft.SqlTools.ServiceLayer.Scripting parameters.OperationId = this.OperationId; } - protected string GetServerNameFromLiveInstance(string connectionString) + protected string GetServerNameFromLiveInstance(string connectionString, string azureAccessToken) { string serverName = null; using (SqlConnection connection = new SqlConnection(connectionString)) { + if (azureAccessToken != null) + { + connection.AccessToken = azureAccessToken; + } connection.Open(); try { - - ServerConnection serverConnection = new ServerConnection(connection); + ServerConnection serverConnection; + if (azureAccessToken == null) + { + serverConnection = new ServerConnection(connection); + } + else + { + serverConnection = new ServerConnection(connection, new AzureAccessToken(azureAccessToken)); + } serverName = serverConnection.TrueName; } catch (SqlException e) diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs index 6ac2d9b8..18a180fa 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs @@ -258,7 +258,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection RetryPolicy connectionRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); RetryPolicy commandRetryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); - ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy); + ReliableSqlConnection connection = new ReliableSqlConnection(csb.ConnectionString, connectionRetryPolicy, commandRetryPolicy, null); return connection; } @@ -283,7 +283,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection logPath = ReliableConnectionHelper.GetDefaultDatabaseLogPath(conn); }, catchException: null, - useRetry: false); + useRetry: false, + azureAccountToken: null); Assert.False(string.IsNullOrWhiteSpace(filePath)); Assert.False(string.IsNullOrWhiteSpace(logPath)); @@ -342,7 +343,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection var connectionBuilder = CreateTestConnectionStringBuilder(); Assert.NotNull(connectionBuilder); - bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder); + bool isReadOnly = ReliableConnectionHelper.IsDatabaseReadonly(connectionBuilder, null); Assert.False(isReadOnly); } @@ -352,7 +353,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection [Fact] public void TestIsDatabaseReadonlyWithNullBuilder() { - Assert.Throws(() => ReliableConnectionHelper.IsDatabaseReadonly(null)); + Assert.Throws(() => ReliableConnectionHelper.IsDatabaseReadonly(null, null)); } /// @@ -361,7 +362,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection [Fact] public void VerifyAnsiNullAndQuotedIdentifierSettingsReplayed() { - using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true)) + using (ReliableSqlConnection conn = (ReliableSqlConnection) ReliableConnectionHelper.OpenConnection(CreateTestConnectionStringBuilder(), useRetry: true, azureAccountToken: null)) { VerifySessionSettings(conn, true); VerifySessionSettings(conn, false); @@ -506,7 +507,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection "SET NOCOUNT ON; SET NOCOUNT OFF;", ReliableConnectionHelper.SetCommandTimeout, null, - true + true, + null ); Assert.NotNull(result); } @@ -519,7 +521,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection { ReliableConnectionHelper.ServerInfo info = null; var connBuilder = CreateTestConnectionStringBuilder(); - Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info)); + Assert.True(ReliableConnectionHelper.TryGetServerVersion(connBuilder.ConnectionString, out info, null)); Assert.NotNull(info); Assert.NotNull(info.ServerVersion); @@ -535,7 +537,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection RunIfWrapper.RunIfWindows(() => { ReliableConnectionHelper.ServerInfo info = null; - Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info)); + Assert.False(ReliableConnectionHelper.TryGetServerVersion("this is not a valid connstr", out info, null)); }); } @@ -679,12 +681,13 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection var result = LiveConnectionHelper.InitLiveConnectionInfo(null, queryTempFile.FilePath); ConnectionInfo connInfo = result.ConnectionInfo; DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default]; + connection.Open(); Assert.True(connection.State == ConnectionState.Open, "Connection should be open."); Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection)); SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); - Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder)); + Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder, null)); ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder()); Assert.NotNull(ReliableConnectionHelper.GetServerName(connection)); diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj index 7bbebf97..b1b1f5de 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Microsoft.SqlTools.ServiceLayer.IntegrationTests.csproj @@ -33,7 +33,7 @@ - + diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj index acecf75b..a2eeb57c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/Microsoft.SqlTools.ServiceLayer.Test.Common.csproj @@ -12,7 +12,7 @@ - + diff --git a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj index 9293e5ab..95c1bb88 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.TestDriver/Microsoft.SqlTools.ServiceLayer.TestDriver.csproj @@ -12,7 +12,7 @@ - + diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 505a19a8..f1f6219b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -80,7 +80,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection }); var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) .Returns(mockConnection.Object); @@ -146,7 +146,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection .Returns(() => Task.Run(() => {})); var mockFactory = new Mock(); - mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny())) + mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) .Returns(mockConnection.Object) .Returns(mockConnection2.Object); @@ -209,7 +209,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection }); var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) .Returns(mockConnection.Object); @@ -282,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection connectionMock.Setup(c => c.Database).Returns(expectedDbName); var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) .Returns(connectionMock.Object); var connectionService = new ConnectionService(mockFactory.Object); @@ -321,8 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection var dummySqlConnection = new TestSqlConnection(null); var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) - .Returns((string connString) => + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) + .Returns((string connString, string azureAccountToken) => { dummySqlConnection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); @@ -775,7 +775,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection // Setup mock connection factory to inject query results var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) .Returns(CreateMockDbConnection(new[] {data})); var connectionService = new ConnectionService(mockFactory.Object); @@ -1324,8 +1324,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection var connection = new TestSqlConnection(null); var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) - .Returns((string connString) => + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) + .Returns((string connString, string azureAccountToken) => { connection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); @@ -1374,8 +1374,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection var connection = mockConnection.Object; var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) - .Returns((string connString) => + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) + .Returns((string connString, string azureAccountToken) => { connection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); @@ -1427,8 +1427,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection var connection = mockConnection.Object; var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) - .Returns((string connString) => + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) + .Returns((string connString, string azureAccountToken) => { connection.ConnectionString = connString; SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); @@ -1484,5 +1484,32 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.Equal(false, details.TrustServerCertificate); Assert.Equal(30, details.ConnectTimeout); } + + [Fact] + public async void ConnectingWithAzureAccountUsesToken() + { + // Set up mock connection factory + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) + .Returns(new TestSqlConnection(null)); + var connectionService = new ConnectionService(mockFactory.Object); + + var details = TestObjects.GetTestConnectionDetails(); + var azureAccountToken = "testAzureAccountToken"; + details.AzureAccountToken = azureAccountToken; + details.UserName = ""; + details.Password = ""; + details.AuthenticationType = "AzureMFA"; + + // If I open a connection using connection details that include an account token + await connectionService.Connect(new ConnectParams + { + OwnerUri = "testURI", + Connection = details + }); + + // Then the connection factory got called with details including an account token + mockFactory.Verify(factory => factory.CreateSqlConnection(It.IsAny(), It.Is(accountToken => accountToken == azureAccountToken)), Times.Once()); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ReliableConnectionTests.cs new file mode 100644 index 00000000..b4262bbe --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ReliableConnectionTests.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection +{ + /// + /// Tests for ReliableConnection code + /// + public class ReliableConnectionTests + { + [Fact] + public void ReliableSqlConnectionUsesAzureToken() + { + ConnectionDetails details = TestObjects.GetTestConnectionDetails(); + details.UserName = ""; + details.Password = ""; + string connectionString = ConnectionService.BuildConnectionString(details); + string azureAccountToken = "testAzureAccountToken"; + RetryPolicy retryPolicy = RetryPolicyFactory.CreateDefaultConnectionRetryPolicy(); + + // If I create a ReliableSqlConnection using an azure account token + var reliableConnection = new ReliableSqlConnection(connectionString, retryPolicy, retryPolicy, azureAccountToken); + + // Then the connection's azureAccountToken gets set + Assert.Equal(azureAccountToken, reliableConnection.GetUnderlyingConnection().AccessToken); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj index 4ff562a3..0a1def63 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj @@ -14,8 +14,8 @@ - - + + diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs index 49f6ed4c..ff749d1f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs @@ -408,7 +408,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer // Stub out the connection to avoid a 30second timeout while attempting to connect. // The tests don't need any connection context anyhow so this doesn't impact the scenario - mockConnectionOpener.Setup(b => b.OpenSqlConnection(It.IsAny(), It.IsAny())) + mockConnectionOpener.Setup(b => b.OpenServerConnection(It.IsAny(), It.IsAny())) .Throws(); connectionServiceMock.Setup(c => c.Connect(It.IsAny())) .Returns((ConnectParams connectParams) => Task.FromResult(GetCompleteParamsForConnection(connectParams.OwnerUri, details))); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs index af243d6e..5874fb82 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs @@ -173,7 +173,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution private static ISqlConnectionFactory CreateMockFactory(TestResultSet[] data, bool throwOnExecute, bool throwOnRead) { var mockFactory = new Mock(); - mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) .Returns(() => CreateTestConnection(data, throwOnExecute, throwOnRead)); return mockFactory.Object; @@ -184,7 +184,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution // Create a connection info and add the default connection to it ISqlConnectionFactory factory = CreateMockFactory(data, throwOnExecute, throwOnRead); ConnectionInfo ci = new ConnectionInfo(factory, Constants.OwnerUri, StandardConnectionDetails); - ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null); + ci.ConnectionTypeToConnectionMap[ConnectionType.Default] = factory.CreateSqlConnection(null, null); return ci; } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/BatchTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/BatchTests.cs index 4b440ef2..2bd80fc0 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/BatchTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/BatchTests.cs @@ -427,7 +427,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution private static DbConnection GetConnection(ConnectionInfo info) { - return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); + return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null); } [SuppressMessage("ReSharper", "UnusedParameter.Local")] diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ResultSetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ResultSetTests.cs index eb8e6fed..debc0ae3 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ResultSetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ResultSetTests.cs @@ -386,7 +386,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution private static DbDataReader GetReader(TestResultSet[] dataSet, bool throwOnRead, string query) { var info = Common.CreateTestConnectionInfo(dataSet, false, throwOnRead); - var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); + var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null); var command = connection.CreateCommand(); command.CommandText = query; return command.ExecuteReader(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ResultSetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ResultSetTests.cs index 52690adc..f6bd4936 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ResultSetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ResultSetTests.cs @@ -174,7 +174,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults private static DbDataReader GetReader(TestResultSet[] dataSet, string query) { var info = Common.CreateTestConnectionInfo(dataSet, false, false); - var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); + var connection = info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails), null); var command = connection.CreateCommand(); command.CommandText = query; return command.ExecuteReader(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs index 5e00f1f4..ec65ae74 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs @@ -8,6 +8,8 @@ using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices; @@ -273,7 +275,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility /// public class TestSqlConnectionFactory : ISqlConnectionFactory { - public DbConnection CreateSqlConnection(string connectionString) + public DbConnection CreateSqlConnection(string connectionString, string azureAccountToken) { return new TestSqlConnection(null) {