diff --git a/src/ServiceHost/ConnectionServices/ConnectionService.cs b/src/ServiceHost/ConnectionServices/ConnectionService.cs
index afbf4ab4..1b62ab50 100644
--- a/src/ServiceHost/ConnectionServices/ConnectionService.cs
+++ b/src/ServiceHost/ConnectionServices/ConnectionService.cs
@@ -6,12 +6,14 @@
using System;
using System.Collections.Generic;
using System.Data.SqlClient;
+using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts;
-namespace Microsoft.SqlTools.ServiceLayer.Connection
+namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices
{
///
/// Main class for the Connection Management services
@@ -119,25 +121,24 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// Open a connection with the specified connection details
///
///
- public ConnectionResult Connect(ConnectionDetails connectionDetails)
+ public async Task Connect(ConnectionDetails connectionDetails)
{
// build the connection string from the input parameters
string connectionString = BuildConnectionString(connectionDetails);
// create a sql connection instance
- ISqlConnection connection = this.ConnectionFactory.CreateSqlConnection();
+ ISqlConnection connection = ConnectionFactory.CreateSqlConnection(connectionString);
// open the database
- connection.OpenDatabaseConnection(connectionString);
+ await connection.OpenAsync();
// map the connection id to the connection object for future lookups
- this.ActiveConnections.Add(++maxConnectionId, connection);
+ ActiveConnections.Add(++maxConnectionId, connection);
// invoke callback notifications
- foreach (var activity in this.onConnectionActivities)
- {
- activity(connection);
- }
+ var onConnectionCallbackTasks = onConnectionActivities.Select(t => t(connection));
+ await Task.WhenAll(onConnectionCallbackTasks);
+ // TODO: Evaulate if we want to avoid waiting here. We'll need error handling on the other side if we don't wait
// return the connection result
return new ConnectionResult()
@@ -178,7 +179,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
Logger.Write(LogLevel.Verbose, "HandleConnectRequest");
// open connection base on request details
- ConnectionResult result = ConnectionService.Instance.Connect(connectionDetails);
+ ConnectionResult result = await Connect(connectionDetails);
await requestContext.SendResult(result);
}
diff --git a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs
index a9a255f3..3ee1cb73 100644
--- a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs
+++ b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnection.cs
@@ -3,21 +3,33 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-using System.Collections.Generic;
+using System.Data;
+using System.Threading;
+using System.Threading.Tasks;
namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
{
///
/// Interface for the SQL Connection wrapper
///
- public interface ISqlConnection
+ public interface ISqlConnection : IDbConnection
{
- ///
- /// Open a connection to the provided connection string
- ///
- ///
- void OpenDatabaseConnection(string connectionString);
+ /////
+ ///// Open a connection to the provided connection string
+ /////
+ /////
+ //void OpenDatabaseConnection(string connectionString);
- IEnumerable GetServerObjects();
+ //IEnumerable GetServerObjects();
+
+ string DataSource { get; }
+
+ string ServerVersion { get; }
+
+ void ClearPool();
+
+ Task OpenAsync();
+
+ Task OpenAsync(CancellationToken token);
}
}
diff --git a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs
index 0a79ec0d..664ca374 100644
--- a/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs
+++ b/src/ServiceHost/ConnectionServices/Contracts/ISqlConnectionFactory.cs
@@ -13,6 +13,6 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
///
/// Create a new SQL Connection object
///
- ISqlConnection CreateSqlConnection();
+ ISqlConnection CreateSqlConnection(string connectionString);
}
}
diff --git a/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs b/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs
index 937deaa1..ee08af1a 100644
--- a/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs
+++ b/src/ServiceHost/ConnectionServices/Contracts/SqlConnection.cs
@@ -3,9 +3,11 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-using System.Collections.Generic;
+using System;
using System.Data;
using System.Data.SqlClient;
+using System.Threading;
+using System.Threading.Tasks;
namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
{
@@ -21,36 +23,161 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
private SqlConnection connection;
///
- /// Opens a SqlConnection using provided connection string
+ /// Creates a new instance of the SqlClientConnection with an underlying connection to the
+ /// database server provided in .
///
- ///
- public void OpenDatabaseConnection(string connectionString)
+ /// Connection string for the database to connect to
+ public SqlClientConnection(string connectionString)
{
- this.connection = new SqlConnection(connectionString);
- this.connection.Open();
+ connection = new SqlConnection(connectionString);
}
- ///
- /// Gets a list of database server schema objects
- ///
- ///
- public IEnumerable GetServerObjects()
- {
- // Select the values from sys.tables to give a super basic
- // autocomplete experience. This will be replaced by SMO.
- SqlCommand command = connection.CreateCommand();
- command.CommandText = "SELECT name FROM sys.tables";
- command.CommandTimeout = 15;
- command.CommandType = CommandType.Text;
- var reader = command.ExecuteReader();
+ /////
+ ///// Gets a list of database server schema objects
+ /////
+ /////
+ //public IEnumerable GetServerObjects()
+ //{
+ // // Select the values from sys.tables to give a super basic
+ // // autocomplete experience. This will be replaced by SMO.
+ // SqlCommand command = connection.CreateCommand();
+ // command.CommandText = "SELECT name FROM sys.tables";
+ // command.CommandTimeout = 15;
+ // command.CommandType = CommandType.Text;
+ // var reader = command.ExecuteReader();
- List results = new List();
- while (reader.Read())
+ // List results = new List();
+ // while (reader.Read())
+ // {
+ // results.Add(reader[0].ToString());
+ // }
+
+ // return results;
+ //}
+
+ #region ISqlConnection Implementation
+
+ #region Properties
+
+ public string ConnectionString
+ {
+ get { return connection.ConnectionString; }
+ set { connection.ConnectionString = value; }
+ }
+
+ public int ConnectionTimeout
+ {
+ get { return connection.ConnectionTimeout; }
+ }
+
+ public string Database
+ {
+ get { return connection.Database; }
+ }
+
+ public string DataSource
+ {
+ get { return connection.DataSource; }
+ }
+
+ public string ServerVersion
+ {
+ get { return connection.ServerVersion; }
+ }
+
+ public ConnectionState State
+ {
+ get { return connection.State; }
+ }
+
+ #endregion
+
+ #region Public Methods
+
+ public IDbTransaction BeginTransaction()
+ {
+ return connection.BeginTransaction();
+ }
+
+ public IDbTransaction BeginTransaction(IsolationLevel il)
+ {
+ return connection.BeginTransaction(il);
+ }
+
+ public void ChangeDatabase(string databaseName)
+ {
+ connection.ChangeDatabase(databaseName);
+ }
+
+ public void ClearPool()
+ {
+ if (connection != null)
{
- results.Add(reader[0].ToString());
+ SqlConnection.ClearPool(connection);
}
-
- return results;
}
+
+ public void Close()
+ {
+ connection.Close();
+ }
+
+ public IDbCommand CreateCommand()
+ {
+ return connection.CreateCommand();
+ }
+
+ public void Open()
+ {
+ connection.Open();
+ }
+
+ public Task OpenAsync()
+ {
+ return connection.OpenAsync();
+ }
+
+ public Task OpenAsync(CancellationToken token)
+ {
+ return connection.OpenAsync(token);
+ }
+
+ #endregion
+
+ #endregion
+
+ #region IDisposable Implementation
+
+ private bool disposed;
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ private void Dispose(bool disposing)
+ {
+ if (!disposed)
+ {
+ if (disposing)
+ {
+ if (connection.State == ConnectionState.Open)
+ {
+ connection.Close();
+ }
+ connection.Dispose();
+ }
+ disposed = true;
+ }
+ }
+
+ ~SqlClientConnection()
+ {
+ Dispose(false);
+ }
+
+ #endregion
+
}
}
diff --git a/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs b/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs
index 79fbae51..9fbe21f0 100644
--- a/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs
+++ b/src/ServiceHost/ConnectionServices/Contracts/SqlConnectionFactory.cs
@@ -15,9 +15,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
///
/// Creates a new SqlClientConnection object
///
- public ISqlConnection CreateSqlConnection()
+ public ISqlConnection CreateSqlConnection(string connectionString)
{
- return new SqlClientConnection();
+ return new SqlClientConnection(connectionString);
}
}
}
diff --git a/src/ServiceHost/Program.cs b/src/ServiceHost/Program.cs
index a3f9cf8b..aa040bad 100644
--- a/src/ServiceHost/Program.cs
+++ b/src/ServiceHost/Program.cs
@@ -7,7 +7,7 @@ using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.WorkspaceServices;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
-using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.ConnectionServices;
namespace Microsoft.SqlTools.ServiceLayer
{