diff --git a/src/ServiceHost/ConnectionServices/ConnectionService.cs b/src/ServiceHost/ConnectionServices/ConnectionService.cs index afbf4ab4..1b62ab50 100644 --- a/src/ServiceHost/ConnectionServices/ConnectionService.cs +++ b/src/ServiceHost/ConnectionServices/ConnectionService.cs @@ -6,12 +6,14 @@ using System; using System.Collections.Generic; using System.Data.SqlClient; +using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.EditorServices.Utility; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts; -namespace Microsoft.SqlTools.ServiceLayer.Connection +namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices { /// /// Main class for the Connection Management services @@ -119,25 +121,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// Open a connection with the specified connection details /// /// - public ConnectionResult Connect(ConnectionDetails connectionDetails) + public async Task Connect(ConnectionDetails connectionDetails) { // build the connection string from the input parameters string connectionString = BuildConnectionString(connectionDetails); // create a sql connection instance - ISqlConnection connection = this.ConnectionFactory.CreateSqlConnection(); + ISqlConnection connection = ConnectionFactory.CreateSqlConnection(connectionString); // open the database - connection.OpenDatabaseConnection(connectionString); + await connection.OpenAsync(); // map the connection id to the connection object for future lookups - this.ActiveConnections.Add(++maxConnectionId, connection); + ActiveConnections.Add(++maxConnectionId, connection); // invoke callback notifications - foreach (var activity in this.onConnectionActivities) - { - activity(connection); - } + var onConnectionCallbackTasks = onConnectionActivities.Select(t => t(connection)); + await Task.WhenAll(onConnectionCallbackTasks); + // TODO: Evaulate if we want to avoid waiting here. We'll need error handling on the other side if we don't wait // return the connection result return new ConnectionResult() @@ -178,7 +179,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection Logger.Write(LogLevel.Verbose, "HandleConnectRequest"); // open connection base on request details - ConnectionResult result = ConnectionService.Instance.Connect(connectionDetails); + ConnectionResult result = await Connect(connectionDetails); await requestContext.SendResult(result); } diff --git a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs index a9a255f3..3ee1cb73 100644 --- a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs +++ b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs @@ -3,21 +3,33 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System.Collections.Generic; +using System.Data; +using System.Threading; +using System.Threading.Tasks; namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts { /// /// Interface for the SQL Connection wrapper /// - public interface ISqlConnection + public interface ISqlConnection : IDbConnection { - /// - /// Open a connection to the provided connection string - /// - /// - void OpenDatabaseConnection(string connectionString); + ///// + ///// Open a connection to the provided connection string + ///// + ///// + //void OpenDatabaseConnection(string connectionString); - IEnumerable GetServerObjects(); + //IEnumerable GetServerObjects(); + + string DataSource { get; } + + string ServerVersion { get; } + + void ClearPool(); + + Task OpenAsync(); + + Task OpenAsync(CancellationToken token); } } diff --git a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs index 0a79ec0d..664ca374 100644 --- a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs +++ b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs @@ -13,6 +13,6 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts /// /// Create a new SQL Connection object /// - ISqlConnection CreateSqlConnection(); + ISqlConnection CreateSqlConnection(string connectionString); } } diff --git a/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs b/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs index 937deaa1..ee08af1a 100644 --- a/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs +++ b/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs @@ -3,9 +3,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System.Collections.Generic; +using System; using System.Data; using System.Data.SqlClient; +using System.Threading; +using System.Threading.Tasks; namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts { @@ -21,36 +23,161 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts private SqlConnection connection; /// - /// Opens a SqlConnection using provided connection string + /// Creates a new instance of the SqlClientConnection with an underlying connection to the + /// database server provided in . /// - /// - public void OpenDatabaseConnection(string connectionString) + /// Connection string for the database to connect to + public SqlClientConnection(string connectionString) { - this.connection = new SqlConnection(connectionString); - this.connection.Open(); + connection = new SqlConnection(connectionString); } - /// - /// Gets a list of database server schema objects - /// - /// - public IEnumerable GetServerObjects() - { - // Select the values from sys.tables to give a super basic - // autocomplete experience. This will be replaced by SMO. - SqlCommand command = connection.CreateCommand(); - command.CommandText = "SELECT name FROM sys.tables"; - command.CommandTimeout = 15; - command.CommandType = CommandType.Text; - var reader = command.ExecuteReader(); + ///// + ///// Gets a list of database server schema objects + ///// + ///// + //public IEnumerable GetServerObjects() + //{ + // // Select the values from sys.tables to give a super basic + // // autocomplete experience. This will be replaced by SMO. + // SqlCommand command = connection.CreateCommand(); + // command.CommandText = "SELECT name FROM sys.tables"; + // command.CommandTimeout = 15; + // command.CommandType = CommandType.Text; + // var reader = command.ExecuteReader(); - List results = new List(); - while (reader.Read()) + // List results = new List(); + // while (reader.Read()) + // { + // results.Add(reader[0].ToString()); + // } + + // return results; + //} + + #region ISqlConnection Implementation + + #region Properties + + public string ConnectionString + { + get { return connection.ConnectionString; } + set { connection.ConnectionString = value; } + } + + public int ConnectionTimeout + { + get { return connection.ConnectionTimeout; } + } + + public string Database + { + get { return connection.Database; } + } + + public string DataSource + { + get { return connection.DataSource; } + } + + public string ServerVersion + { + get { return connection.ServerVersion; } + } + + public ConnectionState State + { + get { return connection.State; } + } + + #endregion + + #region Public Methods + + public IDbTransaction BeginTransaction() + { + return connection.BeginTransaction(); + } + + public IDbTransaction BeginTransaction(IsolationLevel il) + { + return connection.BeginTransaction(il); + } + + public void ChangeDatabase(string databaseName) + { + connection.ChangeDatabase(databaseName); + } + + public void ClearPool() + { + if (connection != null) { - results.Add(reader[0].ToString()); + SqlConnection.ClearPool(connection); } - - return results; } + + public void Close() + { + connection.Close(); + } + + public IDbCommand CreateCommand() + { + return connection.CreateCommand(); + } + + public void Open() + { + connection.Open(); + } + + public Task OpenAsync() + { + return connection.OpenAsync(); + } + + public Task OpenAsync(CancellationToken token) + { + return connection.OpenAsync(token); + } + + #endregion + + #endregion + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!disposed) + { + if (disposing) + { + if (connection.State == ConnectionState.Open) + { + connection.Close(); + } + connection.Dispose(); + } + disposed = true; + } + } + + ~SqlClientConnection() + { + Dispose(false); + } + + #endregion + } } diff --git a/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs b/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs index 79fbae51..9fbe21f0 100644 --- a/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs +++ b/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs @@ -15,9 +15,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts /// /// Creates a new SqlClientConnection object /// - public ISqlConnection CreateSqlConnection() + public ISqlConnection CreateSqlConnection(string connectionString) { - return new SqlClientConnection(); + return new SqlClientConnection(connectionString); } } } diff --git a/src/ServiceHost/Program.cs b/src/ServiceHost/Program.cs index a3f9cf8b..aa040bad 100644 --- a/src/ServiceHost/Program.cs +++ b/src/ServiceHost/Program.cs @@ -7,7 +7,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.WorkspaceServices; using Microsoft.SqlTools.ServiceLayer.LanguageServices; -using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.ConnectionServices; namespace Microsoft.SqlTools.ServiceLayer {