Switching from ISqlConnection to DbConnection

This is a fairly minor change that will save tons of time as we develop
this service. The DbConnection and associated Db* abstract classes
ask for synchronous versions of the code and allow the addition of async
code. The SqlClient implementation already implements Db* abstract
classes, so we can piggy back off that for our dependency injection layer.

Tests and existing code has been updated to handle the change, as well
This commit is contained in:
Benjamin Russell
2016-08-02 18:19:51 -07:00
parent f40aa31c67
commit b2f44031b7
8 changed files with 104 additions and 351 deletions

View File

@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Data.SqlClient;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
@@ -64,15 +65,15 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices
/// <summary>
/// Active connections lazy dictionary instance
/// </summary>
private Lazy<Dictionary<int, ISqlConnection>> activeConnections
= new Lazy<Dictionary<int, ISqlConnection>>(()
=> new Dictionary<int, ISqlConnection>());
private readonly Lazy<Dictionary<int, DbConnection>> activeConnections
= new Lazy<Dictionary<int, DbConnection>>(()
=> new Dictionary<int, DbConnection>());
/// <summary>
/// Callback for onconnection handler
/// </summary>
/// <param name="sqlConnection"></param>
public delegate Task OnConnectionHandler(ISqlConnection sqlConnection);
public delegate Task OnConnectionHandler(DbConnection sqlConnection);
/// <summary>
/// List of onconnection handlers
@@ -82,7 +83,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices
/// <summary>
/// Gets the active connection map
/// </summary>
public Dictionary<int, ISqlConnection> ActiveConnections
public Dictionary<int, DbConnection> ActiveConnections
{
get
{
@@ -128,7 +129,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices
string connectionString = BuildConnectionString(connectionDetails);
// create a sql connection instance
ISqlConnection connection = this.ConnectionFactory.CreateSqlConnection(connectionString);
DbConnection connection = this.ConnectionFactory.CreateSqlConnection(connectionString);
// open the database
connection.Open();

View File

@@ -1,28 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Data;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
{
/// <summary>
/// Interface for the SQL Connection wrapper
/// </summary>
public interface ISqlConnection : IDbConnection
{
string DataSource { get; }
string ServerVersion { get; }
void ClearPool();
Task OpenAsync();
Task OpenAsync(CancellationToken token);
}
}

View File

@@ -3,6 +3,8 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System.Data.Common;
namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
{
/// <summary>
@@ -13,6 +15,6 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
/// <summary>
/// Create a new SQL Connection object
/// </summary>
ISqlConnection CreateSqlConnection(string connectionString);
DbConnection CreateSqlConnection(string connectionString);
}
}

View File

@@ -1,160 +0,0 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Data;
using System.Data.SqlClient;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
{
/// <summary>
/// Wrapper class that implements ISqlConnection and hosts a SqlConnection.
/// This wrapper exists primarily for decoupling to support unit testing.
/// </summary>
public class SqlClientConnection : ISqlConnection
{
/// <summary>
/// the underlying SQL connection
/// </summary>
private SqlConnection connection;
/// <summary>
/// Creates a new instance of the SqlClientConnection with an underlying connection to the
/// database server provided in <paramref name="connectionString"/>.
/// </summary>
/// <param name="connectionString">Connection string for the database to connect to</param>
public SqlClientConnection(string connectionString)
{
connection = new SqlConnection(connectionString);
}
#region ISqlConnection Implementation
#region Properties
public string ConnectionString
{
get { return connection.ConnectionString; }
set { connection.ConnectionString = value; }
}
public int ConnectionTimeout
{
get { return connection.ConnectionTimeout; }
}
public string Database
{
get { return connection.Database; }
}
public string DataSource
{
get { return connection.DataSource; }
}
public string ServerVersion
{
get { return connection.ServerVersion; }
}
public ConnectionState State
{
get { return connection.State; }
}
#endregion
#region Public Methods
public IDbTransaction BeginTransaction()
{
return connection.BeginTransaction();
}
public IDbTransaction BeginTransaction(IsolationLevel il)
{
return connection.BeginTransaction(il);
}
public void ChangeDatabase(string databaseName)
{
connection.ChangeDatabase(databaseName);
}
public void ClearPool()
{
if (connection != null)
{
SqlConnection.ClearPool(connection);
}
}
public void Close()
{
connection.Close();
}
public IDbCommand CreateCommand()
{
return connection.CreateCommand();
}
public void Open()
{
connection.Open();
}
public Task OpenAsync()
{
return connection.OpenAsync();
}
public Task OpenAsync(CancellationToken token)
{
return connection.OpenAsync(token);
}
#endregion
#endregion
#region IDisposable Implementation
private bool disposed;
public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}
private void Dispose(bool disposing)
{
if (!disposed)
{
if (disposing)
{
if (connection.State == ConnectionState.Open)
{
connection.Close();
}
connection.Dispose();
}
disposed = true;
}
}
~SqlClientConnection()
{
Dispose(false);
}
#endregion
}
}

View File

@@ -3,6 +3,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System.Data.Common;
using System.Data.SqlClient;
namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
{
/// <summary>
@@ -15,9 +18,9 @@ namespace Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts
/// <summary>
/// Creates a new SqlClientConnection object
/// </summary>
public ISqlConnection CreateSqlConnection(string connectionString)
public DbConnection CreateSqlConnection(string connectionString)
{
return new SqlClientConnection(connectionString);
return new SqlConnection(connectionString);
}
}
}

View File

@@ -6,10 +6,9 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;
using System.Data.Common;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.ConnectionServices;
using Microsoft.SqlTools.ServiceLayer.ConnectionServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
using Microsoft.SqlTools.ServiceLayer.WorkspaceServices.Contracts;
@@ -66,16 +65,16 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
/// TODO: Update with refactoring/async
/// </summary>
/// <param name="connection"></param>
public async Task UpdateAutoCompleteCache(ISqlConnection connection)
public async Task UpdateAutoCompleteCache(DbConnection connection)
{
IDbCommand command = connection.CreateCommand();
DbCommand command = connection.CreateCommand();
command.CommandText = "SELECT name FROM sys.tables";
command.CommandTimeout = 15;
command.CommandType = CommandType.Text;
var reader = command.ExecuteReader();
var reader = await command.ExecuteReaderAsync();
List<string> results = new List<string>();
while (reader.Read())
while (await reader.ReadAsync())
{
results.Add(reader[0].ToString());
}

View File

@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.EditorServices.Utility;
@@ -309,7 +310,7 @@ namespace Microsoft.SqlTools.ServiceLayer.LanguageServices
/// Callback for when a user connection is done processing
/// </summary>
/// <param name="sqlConnection"></param>
public async Task OnConnection(ISqlConnection sqlConnection)
public async Task OnConnection(DbConnection sqlConnection)
{
await AutoCompleteService.Instance.UpdateAutoCompleteCache(sqlConnection);
await Task.FromResult(true);

View File

@@ -9,6 +9,8 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@@ -86,7 +88,7 @@ namespace Microsoft.SqlTools.Test.Utility
}
}
public class TestSqlReader : IDataReader
public class TestDataReader : DbDataReader
{
#region Test Specific Implementations
@@ -105,149 +107,122 @@ namespace Microsoft.SqlTools.Test.Utility
#endregion
public bool GetBoolean(int i)
public override bool GetBoolean(int ordinal)
{
throw new NotImplementedException();
}
public byte GetByte(int i)
public override byte GetByte(int ordinal)
{
throw new NotImplementedException();
}
public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length)
public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length)
{
throw new NotImplementedException();
}
public char GetChar(int i)
public override char GetChar(int ordinal)
{
throw new NotImplementedException();
}
public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length)
public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length)
{
throw new NotImplementedException();
}
public IDataReader GetData(int i)
public override string GetDataTypeName(int ordinal)
{
throw new NotImplementedException();
}
public string GetDataTypeName(int i)
public override DateTime GetDateTime(int ordinal)
{
throw new NotImplementedException();
}
public DateTime GetDateTime(int i)
public override decimal GetDecimal(int ordinal)
{
throw new NotImplementedException();
}
public decimal GetDecimal(int i)
public override double GetDouble(int ordinal)
{
throw new NotImplementedException();
}
public double GetDouble(int i)
public override IEnumerator GetEnumerator()
{
throw new NotImplementedException();
}
public Type GetFieldType(int i)
public override int GetOrdinal(string name)
{
throw new NotImplementedException();
}
public float GetFloat(int i)
public override string GetName(int ordinal)
{
throw new NotImplementedException();
}
public Guid GetGuid(int i)
public override long GetInt64(int ordinal)
{
throw new NotImplementedException();
}
public short GetInt16(int i)
public override int GetInt32(int ordinal)
{
throw new NotImplementedException();
}
public int GetInt32(int i)
public override short GetInt16(int ordinal)
{
throw new NotImplementedException();
}
public long GetInt64(int i)
public override Guid GetGuid(int ordinal)
{
throw new NotImplementedException();
}
public string GetName(int i)
public override float GetFloat(int ordinal)
{
throw new NotImplementedException();
}
public int GetOrdinal(string name)
public override Type GetFieldType(int ordinal)
{
throw new NotImplementedException();
}
public string GetString(int i)
public override string GetString(int ordinal)
{
throw new NotImplementedException();
}
public object GetValue(int i)
public override object GetValue(int ordinal)
{
throw new NotImplementedException();
}
public int GetValues(object[] values)
public override int GetValues(object[] values)
{
throw new NotImplementedException();
}
public bool IsDBNull(int i)
public override bool IsDBNull(int ordinal)
{
throw new NotImplementedException();
}
public int FieldCount { get; }
object IDataRecord.this[string name]
{
get { return tableEnumerator.Current[name]; }
}
object IDataRecord.this[int i]
{
get { return tableEnumerator.Current[tableEnumerator.Current.Keys.ToArray()[i]]; }
}
public void Dispose()
public override bool NextResult()
{
throw new NotImplementedException();
}
public void Close()
{
throw new NotImplementedException();
}
public DataTable GetSchemaTable()
{
throw new NotImplementedException();
}
public bool NextResult()
{
throw new NotImplementedException();
}
public bool Read()
public override bool Read()
{
if (tableEnumerator == null)
{
@@ -263,146 +238,103 @@ namespace Microsoft.SqlTools.Test.Utility
return tableEnumerator.MoveNext();
}
public int Depth { get; }
public bool IsClosed { get; }
public int RecordsAffected { get; }
public override int Depth { get; }
public override bool IsClosed { get; }
public override int RecordsAffected { get; }
public override object this[string name]
{
get { return tableEnumerator.Current[name]; }
}
public override object this[int ordinal]
{
get { return tableEnumerator.Current[tableEnumerator.Current.Keys.ToArray()[ordinal]]; }
}
public override int FieldCount { get; }
public override bool HasRows { get; }
}
/// <summary>
/// Test mock class for IDbCommand
/// </summary>
public class TestSqlCommand : IDbCommand
public class TestSqlCommand : DbCommand
{
public string CommandText { get; set; }
public int CommandTimeout { get; set; }
public CommandType CommandType { get; set; }
public IDbConnection Connection { get; set; }
public IDataParameterCollection Parameters
{
get
{
throw new NotImplementedException();
}
}
public IDbTransaction Transaction { get; set; }
public UpdateRowSource UpdatedRowSource { get; set; }
public void Cancel()
public override void Cancel()
{
throw new NotImplementedException();
}
public IDbDataParameter CreateParameter()
public override int ExecuteNonQuery()
{
throw new NotImplementedException();
}
public void Dispose()
{
}
public int ExecuteNonQuery()
public override object ExecuteScalar()
{
throw new NotImplementedException();
}
public IDataReader ExecuteReader()
{
return new TestSqlReader
{
SqlCommandText = CommandText
};
}
public IDataReader ExecuteReader(CommandBehavior behavior)
public override void Prepare()
{
throw new NotImplementedException();
}
public object ExecuteScalar()
public override string CommandText { get; set; }
public override int CommandTimeout { get; set; }
public override CommandType CommandType { get; set; }
public override UpdateRowSource UpdatedRowSource { get; set; }
protected override DbConnection DbConnection { get; set; }
protected override DbParameterCollection DbParameterCollection { get; }
protected override DbTransaction DbTransaction { get; set; }
public override bool DesignTimeVisible { get; set; }
protected override DbParameter CreateDbParameter()
{
throw new NotImplementedException();
}
public void Prepare()
protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
{
throw new NotImplementedException();
return new TestDataReader {SqlCommandText = CommandText};
}
}
/// <summary>
/// Test mock class for SqlConnection wrapper
/// </summary>
public class TestSqlConnection : ISqlConnection
public class TestSqlConnection : DbConnection
{
public TestSqlConnection(string connectionString)
protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel)
{
throw new NotImplementedException();
}
public void Dispose()
public override void Close()
{
throw new System.NotImplementedException();
throw new NotImplementedException();
}
public IDbTransaction BeginTransaction()
public override void Open()
{
throw new System.NotImplementedException();
// No Op
}
public IDbTransaction BeginTransaction(IsolationLevel il)
public override string ConnectionString { get; set; }
public override string Database { get; }
public override ConnectionState State { get; }
public override string DataSource { get; }
public override string ServerVersion { get; }
protected override DbCommand CreateDbCommand()
{
throw new System.NotImplementedException();
return new TestSqlCommand();
}
public void Close()
public override void ChangeDatabase(string databaseName)
{
throw new System.NotImplementedException();
}
public IDbCommand CreateCommand()
{
return new TestSqlCommand {Connection = this};
}
public void Open()
{
// No Op.
}
public string ConnectionString { get; set; }
public int ConnectionTimeout { get; }
public string Database { get; }
public ConnectionState State { get; }
public void ChangeDatabase(string databaseName)
{
throw new System.NotImplementedException();
}
public string DataSource { get; }
public string ServerVersion { get; }
public void ClearPool()
{
throw new System.NotImplementedException();
}
public async Task OpenAsync()
{
// No Op.
await Task.FromResult(0);
}
public Task OpenAsync(CancellationToken token)
{
throw new System.NotImplementedException();
throw new NotImplementedException();
}
}
@@ -411,9 +343,12 @@ namespace Microsoft.SqlTools.Test.Utility
/// </summary>
public class TestSqlConnectionFactory : ISqlConnectionFactory
{
public ISqlConnection CreateSqlConnection(string connectionString)
public DbConnection CreateSqlConnection(string connectionString)
{
return new TestSqlConnection(connectionString);
return new TestSqlConnection()
{
ConnectionString = connectionString
};
}
}
}