diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/CancelTokenKey.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/CancelTokenKey.cs
new file mode 100644
index 00000000..9976916b
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/CancelTokenKey.cs
@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection
+{
+ ///
+ /// Used to uniquely identify a CancellationTokenSource associated with both
+ /// a string URI and a string connection type.
+ ///
+ public class CancelTokenKey : CancelConnectParams, IEquatable
+ {
+ public override bool Equals(object obj)
+ {
+ CancelTokenKey other = obj as CancelTokenKey;
+ if (other == null)
+ {
+ return false;
+ }
+
+ return other.OwnerUri == OwnerUri && other.Type == Type;
+ }
+
+ public bool Equals(CancelTokenKey obj)
+ {
+ return obj.OwnerUri == OwnerUri && obj.Type == Type;
+ }
+
+ public override int GetHashCode()
+ {
+ return OwnerUri.GetHashCode() ^ Type.GetHashCode();
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
index 5d140b9a..168af4a7 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
@@ -4,8 +4,12 @@
//
using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
using System.Data.Common;
+using System.Linq;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
@@ -23,7 +27,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
OwnerUri = ownerUri;
ConnectionDetails = details;
ConnectionId = Guid.NewGuid();
- IntellisenseMetrics = new InteractionMetrics(new int[] { 50, 100, 200, 500, 1000, 2000 });
+ IntellisenseMetrics = new InteractionMetrics(new int[] {50, 100, 200, 500, 1000, 2000});
}
///
@@ -39,17 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
///
/// Factory used for creating the SQL connection associated with the connection info.
///
- public ISqlConnectionFactory Factory {get; private set;}
+ public ISqlConnectionFactory Factory { get; private set; }
///
/// Properties used for creating/opening the SQL connection.
///
public ConnectionDetails ConnectionDetails { get; private set; }
-
+
///
- /// The connection to the SQL database that commands will be run against.
+ /// A map containing all connections to the database that are associated with
+ /// this ConnectionInfo's OwnerUri.
+ /// This is internal for testing access only
///
- public DbConnection SqlConnection { get; set; }
+ internal readonly ConcurrentDictionary ConnectionTypeToConnectionMap =
+ new ConcurrentDictionary();
///
/// Intellisense Metrics
@@ -60,8 +67,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// Returns true is the db connection is to a SQL db
///
public bool IsAzure { get; set; }
-
- ///
+
/// Returns true if the sql connection is to a DW instance
///
public bool IsSqlDW { get; set; }
@@ -71,5 +77,86 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
///
public int MajorVersion { get; set; }
+ ///
+ /// All DbConnection instances held by this ConnectionInfo
+ ///
+ public ICollection AllConnections
+ {
+ get
+ {
+ return ConnectionTypeToConnectionMap.Values;
+ }
+ }
+
+ ///
+ /// All connection type strings held by this ConnectionInfo
+ ///
+ ///
+ public ICollection AllConnectionTypes
+ {
+ get
+ {
+ return ConnectionTypeToConnectionMap.Keys;
+ }
+ }
+
+ ///
+ /// The count of DbConnectioninstances held by this ConnectionInfo
+ ///
+ public int CountConnections
+ {
+ get
+ {
+ return ConnectionTypeToConnectionMap.Count;
+ }
+ }
+
+ ///
+ /// Try to get the DbConnection associated with the given connection type string.
+ ///
+ /// true if a connection with type connectionType was located and out connection was set,
+ /// false otherwise
+ /// Thrown when connectionType is null or empty
+ public bool TryGetConnection(string connectionType, out DbConnection connection)
+ {
+ Validate.IsNotNullOrEmptyString("Connection Type", connectionType);
+ return ConnectionTypeToConnectionMap.TryGetValue(connectionType, out connection);
+ }
+
+ ///
+ /// Adds a DbConnection to this object and associates it with the given
+ /// connection type string. If a connection already exists with an identical
+ /// connection type string, it is not overwritten. Ignores calls where connectionType = null
+ ///
+ /// Thrown when connectionType is null or empty
+ public void AddConnection(string connectionType, DbConnection connection)
+ {
+ Validate.IsNotNullOrEmptyString("Connection Type", connectionType);
+ ConnectionTypeToConnectionMap.TryAdd(connectionType, connection);
+ }
+
+ ///
+ /// Removes the single DbConnection instance associated with string connectionType
+ ///
+ /// Thrown when connectionType is null or empty
+ public void RemoveConnection(string connectionType)
+ {
+ Validate.IsNotNullOrEmptyString("Connection Type", connectionType);
+ DbConnection connection;
+ ConnectionTypeToConnectionMap.TryRemove(connectionType, out connection);
+ }
+
+ ///
+ /// Removes all DbConnection instances held by this object
+ ///
+ public void RemoveAllConnections()
+ {
+ foreach (var type in AllConnectionTypes)
+ {
+ DbConnection connection;
+ ConnectionTypeToConnectionMap.TryRemove(type, out connection);
+ }
+ }
+
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index f9092c70..2ca48d8c 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -52,7 +52,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
private readonly Dictionary ownerToConnectionMap = new Dictionary();
- private readonly ConcurrentDictionary ownerToCancellationTokenSourceMap = new ConcurrentDictionary();
+ ///
+ /// A map containing all CancellationTokenSource objects that are associated with a given URI/ConnectionType pair.
+ /// Entries in this map correspond to DbConnection instances that are in the process of connecting.
+ ///
+ private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap =
+ new ConcurrentDictionary();
private readonly object cancellationTokenSourceLock = new object();
@@ -120,6 +125,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
return this.connectionFactory;
}
+
+ internal set { this.connectionFactory = value; }
}
///
@@ -138,12 +145,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
///
- /// Open a connection with the specified connection details
+ /// Validates the given ConnectParams object.
///
- ///
- public async Task Connect(ConnectParams connectionParams)
+ /// The params to validate
+ /// A ConnectionCompleteParams object upon validation error,
+ /// null upon validation success
+ public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams)
{
- // Validate parameters
string paramValidationErrorMessage;
if (connectionParams == null)
{
@@ -161,29 +169,139 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
};
}
- // Resolve if it is an existing connection
- // Disconnect active connection if the URI is already connected
+ // return null upon success
+ return null;
+ }
+
+ ///
+ /// Open a connection with the specified ConnectParams
+ ///
+ public async Task Connect(ConnectParams connectionParams)
+ {
+ // Validate parameters
+ ConnectionCompleteParams validationResults = ValidateConnectParams(connectionParams);
+ if (validationResults != null)
+ {
+ return validationResults;
+ }
+
+ // If there is no ConnectionInfo in the map, create a new ConnectionInfo,
+ // but wait until later when we are connected to add it to the map.
ConnectionInfo connectionInfo;
- if (ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo) )
+ if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo))
+ {
+ connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
+ }
+
+ // Resolve if it is an existing connection
+ // Disconnect active connection if the URI is already connected for this connection type
+ DbConnection existingConnection;
+ if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection))
{
var disconnectParams = new DisconnectParams()
{
- OwnerUri = connectionParams.OwnerUri
+ OwnerUri = connectionParams.OwnerUri,
+ Type = connectionParams.Type
};
Disconnect(disconnectParams);
}
- connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
- // try to connect
- var response = new ConnectionCompleteParams {OwnerUri = connectionParams.OwnerUri};
+ // Try to open a connection with the given ConnectParams
+ ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams);
+ if (response != null)
+ {
+ return response;
+ }
+
+ // If this is the first connection for this URI, add the ConnectionInfo to the map
+ if (!ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri))
+ {
+ ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
+ }
+
+ // Invoke callback notifications
+ InvokeOnConnectionActivities(connectionInfo, connectionParams);
+
+ // Return information about the connected SQL Server instance
+ return GetConnectionCompleteParams(connectionParams.Type, connectionInfo);
+ }
+
+ ///
+ /// Creates a ConnectionCompleteParams as a response to a successful connection.
+ /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo.
+ ///
+ /// A ConnectionCompleteParams in response to the successful connection
+ private ConnectionCompleteParams GetConnectionCompleteParams(string connectionType, ConnectionInfo connectionInfo)
+ {
+ ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionType };
+
+ try
+ {
+ DbConnection connection;
+ connectionInfo.TryGetConnection(connectionType, out connection);
+
+ // Update with the actual database name in connectionInfo and result
+ // Doing this here as we know the connection is open - expect to do this only on connecting
+ connectionInfo.ConnectionDetails.DatabaseName = connection.Database;
+ response.ConnectionSummary = new ConnectionSummary
+ {
+ ServerName = connectionInfo.ConnectionDetails.ServerName,
+ DatabaseName = connectionInfo.ConnectionDetails.DatabaseName,
+ UserName = connectionInfo.ConnectionDetails.UserName,
+ };
+
+ response.ConnectionId = connectionInfo.ConnectionId.ToString();
+
+ var reliableConnection = connection as ReliableSqlConnection;
+ DbConnection underlyingConnection = reliableConnection != null
+ ? reliableConnection.GetUnderlyingConnection()
+ : connection;
+
+ ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(underlyingConnection);
+ response.ServerInfo = new ServerInfo
+ {
+ ServerMajorVersion = serverInfo.ServerMajorVersion,
+ ServerMinorVersion = serverInfo.ServerMinorVersion,
+ ServerReleaseVersion = serverInfo.ServerReleaseVersion,
+ EngineEditionId = serverInfo.EngineEditionId,
+ ServerVersion = serverInfo.ServerVersion,
+ ServerLevel = serverInfo.ServerLevel,
+ ServerEdition = serverInfo.ServerEdition,
+ IsCloud = serverInfo.IsCloud,
+ AzureVersion = serverInfo.AzureVersion,
+ OsVersion = serverInfo.OsVersion
+ };
+ connectionInfo.IsAzure = serverInfo.IsCloud;
+ connectionInfo.MajorVersion = serverInfo.ServerMajorVersion;
+ connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
+ }
+ catch (Exception ex)
+ {
+ response.Messages = ex.ToString();
+ }
+
+ return response;
+ }
+
+ ///
+ /// Tries to create and open a connection with the given ConnectParams.
+ ///
+ /// null upon success, a ConnectionCompleteParams detailing the error upon failure
+ private async Task TryOpenConnection(ConnectionInfo connectionInfo, ConnectParams connectionParams)
+ {
CancellationTokenSource source = null;
+ DbConnection connection = null;
+ CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type };
+ ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type };
+
try
{
// build the connection string from the input parameters
string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
// create a sql connection instance
- connectionInfo.SqlConnection = connectionInfo.Factory.CreateSqlConnection(connectionString);
+ connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
+ connectionInfo.AddConnection(connectionParams.Type, connection);
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
using (source = new CancellationTokenSource())
@@ -193,11 +311,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
// If the URI is currently connecting from a different request, cancel it before we try to connect
CancellationTokenSource currentSource;
- if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out currentSource))
+ if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource))
{
currentSource.Cancel();
}
- ownerToCancellationTokenSourceMap[connectionParams.OwnerUri] = source;
+ cancelTupleToCancellationTokenSourceMap[cancelKey] = source;
}
// Create a task to handle cancellation requests
@@ -214,10 +332,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
});
- var openTask = Task.Run(async () => {
- await connectionInfo.SqlConnection.OpenAsync(source.Token);
+ var openTask = Task.Run(async () =>
+ {
+ await connection.OpenAsync(source.Token);
});
-
+
// Open the connection
await Task.WhenAny(openTask, cancellationTask).Unwrap();
source.Cancel();
@@ -250,60 +369,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
// Only remove the token from the map if it is the same one created by this request
CancellationTokenSource sourceValue;
- if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out sourceValue) && sourceValue == source)
+ if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out sourceValue) && sourceValue == source)
{
- ownerToCancellationTokenSourceMap.TryRemove(connectionParams.OwnerUri, out sourceValue);
+ cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue);
}
}
}
- ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
+ // Return null upon success
+ return null;
+ }
- // Update with the actual database name in connectionInfo and result
- // Doing this here as we know the connection is open - expect to do this only on connecting
- connectionInfo.ConnectionDetails.DatabaseName = connectionInfo.SqlConnection.Database;
- response.ConnectionSummary = new ConnectionSummary
+ ///
+ /// Gets the existing connection with the given URI and connection type string. If none exists,
+ /// creates a new connection. This cannot be used to create a default connection or to create a
+ /// connection if a default connection does not exist.
+ ///
+ public async Task GetOrOpenConnection(string ownerUri, string connectionType)
+ {
+ if (string.IsNullOrEmpty(ownerUri) || string.IsNullOrEmpty(connectionType))
{
- ServerName = connectionInfo.ConnectionDetails.ServerName,
- DatabaseName = connectionInfo.ConnectionDetails.DatabaseName,
- UserName = connectionInfo.ConnectionDetails.UserName,
- };
+ return null;
+ }
- // invoke callback notifications
- InvokeOnConnectionActivities(connectionInfo);
-
- // try to get information about the connected SQL Server instance
- try
+ // Try to get the ConnectionInfo, if it exists
+ ConnectionInfo connectionInfo = ownerToConnectionMap[ownerUri];
+ if (connectionInfo == null)
{
- var reliableConnection = connectionInfo.SqlConnection as ReliableSqlConnection;
- DbConnection connection = reliableConnection != null ? reliableConnection.GetUnderlyingConnection() : connectionInfo.SqlConnection;
-
- ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(connection);
- response.ServerInfo = new ServerInfo
+ return null;
+ }
+
+ // Make sure a default connection exists
+ DbConnection defaultConnection;
+ if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection))
+ {
+ return null;
+ }
+
+ // Try to get the DbConnection
+ DbConnection connection;
+ if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType)
+ {
+ // If the DbConnection does not exist and is not the default connection, create one.
+ // We can't create the default (initial) connection here because we won't have a ConnectionDetails
+ // if Connect() has not yet been called.
+ ConnectParams connectParams = new ConnectParams()
{
- ServerMajorVersion = serverInfo.ServerMajorVersion,
- ServerMinorVersion = serverInfo.ServerMinorVersion,
- ServerReleaseVersion = serverInfo.ServerReleaseVersion,
- EngineEditionId = serverInfo.EngineEditionId,
- ServerVersion = serverInfo.ServerVersion,
- ServerLevel = serverInfo.ServerLevel,
- ServerEdition = serverInfo.ServerEdition,
- IsCloud = serverInfo.IsCloud,
- AzureVersion = serverInfo.AzureVersion,
- OsVersion = serverInfo.OsVersion
+ OwnerUri = ownerUri,
+ Connection = connectionInfo.ConnectionDetails,
+ Type = connectionType
};
- connectionInfo.IsAzure = serverInfo.IsCloud;
- connectionInfo.MajorVersion = serverInfo.ServerMajorVersion;
- connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
- }
- catch(Exception ex)
- {
- response.Messages = ex.ToString();
+ await Connect(connectParams);
+ connectionInfo.TryGetConnection(connectionType, out connection);
}
- // return the connection result
- response.ConnectionId = connectionInfo.ConnectionId.ToString();
- return response;
+ return connection;
}
///
@@ -317,9 +437,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return false;
}
+ CancelTokenKey cancelKey = new CancelTokenKey
+ {
+ OwnerUri = cancelParams.OwnerUri,
+ Type = cancelParams.Type
+ };
+
// Cancel any current connection attempts for this URI
CancellationTokenSource source;
- if (ownerToCancellationTokenSourceMap.TryGetValue(cancelParams.OwnerUri, out source))
+ if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out source))
{
try
{
@@ -331,10 +457,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
return false;
}
}
- else
- {
- return false;
- }
+
+ return false;
}
///
@@ -349,58 +473,128 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
// Cancel if we are in the middle of connecting
- if (CancelConnect(new CancelConnectParams() { OwnerUri = disconnectParams.OwnerUri }))
+ if (CancelConnections(disconnectParams.OwnerUri, disconnectParams.Type))
{
return false;
}
- // Lookup the connection owned by the URI
+ // Lookup the ConnectionInfo owned by the URI
ConnectionInfo info;
if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info))
{
return false;
}
- if (ServiceHost != null)
+ // Call Close() on the connections we want to disconnect
+ // If no connections were located, return false
+ if (!CloseConnections(info, disconnectParams.Type))
+ {
+ return false;
+ }
+
+ // Remove the disconnected connections from the ConnectionInfo map
+ if (disconnectParams.Type == null)
+ {
+ info.RemoveAllConnections();
+ }
+ else
+ {
+ info.RemoveConnection(disconnectParams.Type);
+ }
+
+ // If the ConnectionInfo has no more connections, remove the ConnectionInfo
+ if (info.CountConnections == 0)
+ {
+ ownerToConnectionMap.Remove(disconnectParams.OwnerUri);
+ }
+
+ // Handle Telemetry disconnect events if we are disconnecting the default connection
+ if (disconnectParams.Type == null || disconnectParams.Type == ConnectionType.Default)
+ {
+ HandleDisconnectTelemetry(info);
+ InvokeOnDisconnectionActivities(info);
+ }
+
+ // Return true upon success
+ return true;
+ }
+
+ ///
+ /// Cancel connections associated with the given ownerUri.
+ /// If connectionType is not null, cancel the connection with the given connectionType
+ /// If connectionType is null, cancel all pending connections associated with ownerUri.
+ ///
+ /// true if a single pending connection associated with the non-null connectionType was
+ /// found and cancelled, false otherwise
+ private bool CancelConnections(string ownerUri, string connectionType)
+ {
+ // Cancel the connection of the given type
+ if (connectionType != null)
+ {
+ // If we are trying to disconnect a specific connection and it was just cancelled,
+ // this will return true
+ return CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = connectionType });
+ }
+
+ // Cancel all pending connections
+ foreach (var entry in cancelTupleToCancellationTokenSourceMap)
+ {
+ string entryConnectionUri = entry.Key.OwnerUri;
+ string entryConnectionType = entry.Key.Type;
+ if (ownerUri.Equals(entryConnectionUri))
+ {
+ CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = entryConnectionType });
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Closes DbConnections associated with the given ConnectionInfo.
+ /// If connectionType is not null, closes the DbConnection with the type given by connectionType.
+ /// If connectionType is null, closes all DbConnections.
+ ///
+ /// true if connections were found and attempted to be closed,
+ /// false if no connections were found
+ private bool CloseConnections(ConnectionInfo connectionInfo, string connectionType)
+ {
+ ICollection connectionsToDisconnect = new List();
+ if (connectionType == null)
+ {
+ connectionsToDisconnect = connectionInfo.AllConnections;
+ }
+ else
+ {
+ // Make sure there is an existing connection of this type
+ DbConnection connection;
+ if (!connectionInfo.TryGetConnection(connectionType, out connection))
+ {
+ return false;
+ }
+ connectionsToDisconnect.Add(connection);
+ }
+
+ if (connectionsToDisconnect.Count == 0)
+ {
+ return false;
+ }
+
+ foreach (DbConnection connection in connectionsToDisconnect)
{
try
{
- // Send a telemetry notification for intellisense performance metrics
- ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams()
- {
- Params = new TelemetryProperties
- {
- Properties = new Dictionary
- {
- { "IsAzure", info.IsAzure ? "1" : "0" }
- },
- EventName = TelemetryEventNames.IntellisenseQuantile,
- Measures = info.IntellisenseMetrics.Quantile
- }
- });
+ connection.Close();
}
- catch (Exception ex)
+ catch (Exception)
{
- Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString());
+ // Ignore
}
}
- // Close the connection
- info.SqlConnection.Close();
-
- // Remove URI mapping
- ownerToConnectionMap.Remove(disconnectParams.OwnerUri);
-
- // Invoke callback notifications
- foreach (var activity in this.onDisconnectActivities)
- {
- activity(info.ConnectionDetails, disconnectParams.OwnerUri);
- }
-
- // Success
return true;
}
-
+
///
/// List all databases on the server specified
///
@@ -486,11 +680,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
onDisconnectActivities.Add(activity);
}
-
+
///
/// Handle new connection requests
///
- ///
+ ///
///
///
protected async Task HandleConnectRequest(
@@ -517,8 +711,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
try
{
+ // result is null if the ConnectParams was successfully validated
+ ConnectionCompleteParams result = ValidateConnectParams(connectParams);
+ if (result != null)
+ {
+ await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
+ return;
+ }
+
// open connection based on request details
- ConnectionCompleteParams result = await Instance.Connect(connectParams);
+ result = await Instance.Connect(connectParams);
await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
}
catch (Exception ex)
@@ -744,10 +946,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
try
{
- if (info.SqlConnection.State == ConnectionState.Open)
+ foreach (DbConnection connection in info.AllConnections)
{
- info.SqlConnection.ChangeDatabase(newDatabaseName);
+ if (connection.State == ConnectionState.Open)
+ {
+ connection.ChangeDatabase(newDatabaseName);
+ }
}
+
info.ConnectionDetails.DatabaseName = newDatabaseName;
// Fire a connection changed event
@@ -771,13 +977,65 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
}
- private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo)
+ ///
+ /// Invokes the initial on-connect activities if the provided ConnectParams represents the default
+ /// connection.
+ ///
+ private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams)
{
+ if (connectParams.Type != ConnectionType.Default)
+ {
+ return;
+ }
+
foreach (var activity in this.onConnectionActivities)
{
// not awaiting here to allow handlers to run in the background
activity(connectionInfo);
}
}
+
+ ///
+ /// Invokes the fianl on-disconnect activities if the provided DisconnectParams represents the default
+ /// connection or is null - representing that all connections are being disconnected.
+ ///
+ private void InvokeOnDisconnectionActivities(ConnectionInfo connectionInfo)
+ {
+ foreach (var activity in this.onDisconnectActivities)
+ {
+ activity(connectionInfo.ConnectionDetails, connectionInfo.OwnerUri);
+ }
+ }
+
+ ///
+ /// Handles the Telemetry events that occur upon disconnect.
+ ///
+ ///
+ private void HandleDisconnectTelemetry(ConnectionInfo connectionInfo)
+ {
+ if (ServiceHost != null)
+ {
+ try
+ {
+ // Send a telemetry notification for intellisense performance metrics
+ ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams()
+ {
+ Params = new TelemetryProperties
+ {
+ Properties = new Dictionary
+ {
+ {"IsAzure", connectionInfo.IsAzure ? "1" : "0"}
+ },
+ EventName = TelemetryEventNames.IntellisenseQuantile,
+ Measures = connectionInfo.IntellisenseMetrics.Quantile
+ }
+ });
+ }
+ catch (Exception ex)
+ {
+ Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString());
+ }
+ }
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs
new file mode 100644
index 00000000..aabbb7d9
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs
@@ -0,0 +1,19 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection
+{
+ ///
+ /// String constants that represent connection types.
+ ///
+ /// Default: Connection used by the editor. Opened by the editor upon the initial connection.
+ /// Query: Connection used for executing queries. Opened when the first query is executed.
+ ///
+ public static class ConnectionType
+ {
+ public const string Default = "Default";
+ public const string Query = "Query";
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs
index 9f2efdb0..9c76e590 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs
@@ -14,6 +14,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// A URI identifying the owner of the connection. This will most commonly be a file in the workspace
/// or a virtual file representing an object in a database.
///
- public string OwnerUri { get; set; }
+ public string OwnerUri { get; set; }
+
+ ///
+ /// The type of connection we are trying to cancel
+ ///
+ public string Type { get; set; } = ConnectionType.Default;
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
index 31dad8c5..58d38ee4 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
@@ -22,5 +22,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// connection properties to the same database.
///
public ConnectionDetails Connection { get; set; }
+
+ ///
+ /// The type of this connection. By default, this is set to ConnectionType.Default.
+ ///
+ public string Type { get; set; } = ConnectionType.Default;
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs
index 50517a52..0203a110 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs
@@ -47,6 +47,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// Gets or sets the actual Connection established, including Database Name
///
public ConnectionSummary ConnectionSummary { get; set; }
+
+ ///
+ /// The type of connection that this notification is for
+ ///
+ public string Type { get; set; } = ConnectionType.Default;
}
///
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs
index 91bc7faf..d7e5de72 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs
@@ -15,5 +15,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// or a virtual file representing an object in a database.
///
public string OwnerUri { get; set; }
+
+ ///
+ /// The type of connection we are disconnecting. If null, we will disconnect all connections.
+ /// connections.
+ ///
+ public string Type { get; set; }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
index 48b09b20..f9d2ad10 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
@@ -15,6 +15,7 @@ using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Utility;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using System.Collections.Generic;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
@@ -369,82 +370,62 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
return;
}
- // Open up a connection for querying the database
- string connectionString = ConnectionService.BuildConnectionString(editorConnection.ConnectionDetails);
- // TODO: Don't create a new connection every time, see TFS #834978
- using (DbConnection conn = editorConnection.Factory.CreateSqlConnection(connectionString))
+ // Locate and setup the connection
+ DbConnection queryConnection = await ConnectionService.Instance.GetOrOpenConnection(editorConnection.OwnerUri, ConnectionType.Query);
+ ReliableSqlConnection sqlConn = queryConnection as ReliableSqlConnection;
+ if (sqlConn != null)
{
- try
+ // Subscribe to database informational messages
+ sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage;
+ }
+
+ try
+ {
+ // Execute beforeBatches synchronously, before the user defined batches
+ foreach (Batch b in BeforeBatches)
{
- await conn.OpenAsync();
- }
- catch (Exception exception)
- {
- this.HasExecuted = true;
- if (QueryConnectionException != null)
- {
- await QueryConnectionException(exception.Message);
- }
- return;
+ await b.Execute(queryConnection, cancellationSource.Token);
}
- ReliableSqlConnection sqlConn = conn as ReliableSqlConnection;
+ // We need these to execute synchronously, otherwise the user will be very unhappy
+ foreach (Batch b in Batches)
+ {
+ // Add completion callbacks
+ b.BatchStart += BatchStarted;
+ b.BatchCompletion += BatchCompleted;
+ b.BatchMessageSent += BatchMessageSent;
+ b.ResultSetCompletion += ResultSetCompleted;
+ await b.Execute(queryConnection, cancellationSource.Token);
+ }
+
+ // Execute afterBatches synchronously, after the user defined batches
+ foreach (Batch b in AfterBatches)
+ {
+ await b.Execute(queryConnection, cancellationSource.Token);
+ }
+
+ // Call the query execution callback
+ if (QueryCompleted != null)
+ {
+ await QueryCompleted(this);
+ }
+ }
+ catch (Exception)
+ {
+ // Call the query failure callback
+ if (QueryFailed != null)
+ {
+ await QueryFailed(this);
+ }
+ }
+ finally
+ {
if (sqlConn != null)
{
// Subscribe to database informational messages
- sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage;
+ sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage;
}
-
- try
- {
- // Execute beforeBatches synchronously, before the user defined batches
- foreach (Batch b in BeforeBatches)
- {
- await b.Execute(conn, cancellationSource.Token);
- }
-
- // We need these to execute synchronously, otherwise the user will be very unhappy
- foreach (Batch b in Batches)
- {
- // Add completion callbacks
- b.BatchStart += BatchStarted;
- b.BatchCompletion += BatchCompleted;
- b.BatchMessageSent += BatchMessageSent;
- b.ResultSetCompletion += ResultSetCompleted;
- await b.Execute(conn, cancellationSource.Token);
- }
-
- // Execute afterBatches synchronously, after the user defined batches
- foreach (Batch b in AfterBatches)
- {
- await b.Execute(conn, cancellationSource.Token);
- }
-
- // Call the query execution callback
- if (QueryCompleted != null)
- {
- await QueryCompleted(this);
- }
- }
- catch (Exception)
- {
- // Call the query failure callback
- if (QueryFailed != null)
- {
- await QueryFailed(this);
- }
- }
- finally
- {
- if (sqlConn != null)
- {
- // Subscribe to database informational messages
- sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage;
- }
- }
-
- // TODO: Close connection after eliminating using statement for above TODO
- }
+ }
}
///
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ConnectionServiceTests.cs
new file mode 100644
index 00000000..b78a2770
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ConnectionServiceTests.cs
@@ -0,0 +1,104 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.Collections.Generic;
+using System.Data.Common;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using Microsoft.SqlTools.Test.Utility;
+using Xunit;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution;
+using Microsoft.SqlTools.ServiceLayer.SqlContext;
+using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution;
+using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
+{
+ ///
+ /// Tests for the ServiceHost Connection Service tests that require a live database connection
+ ///
+ public class ConnectionServiceTests
+ {
+ [Fact]
+ public async Task RunningMultipleQueriesCreatesOnlyOneConnection()
+ {
+ // Connect/disconnect twice to ensure reconnection can occur
+ ConnectionService service = ConnectionService.Instance;
+ service.OwnerToConnectionMap.Clear();
+ for (int i = 0; i < 2; i++)
+ {
+ var result = await TestObjects.InitLiveConnectionInfo();
+ ConnectionInfo connectionInfo = result.ConnectionInfo;
+ string uri = connectionInfo.OwnerUri;
+
+ // We should see one ConnectionInfo and one DbConnection
+ Assert.Equal(1, connectionInfo.CountConnections);
+ Assert.Equal(1, service.OwnerToConnectionMap.Count);
+
+ // If we run a query
+ var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
+ Query query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
+ query.Execute();
+ query.ExecutionTask.Wait();
+
+ // We should see two DbConnections
+ Assert.Equal(2, connectionInfo.CountConnections);
+
+ // If we run another query
+ query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
+ query.Execute();
+ query.ExecutionTask.Wait();
+
+ // We should still have 2 DbConnections
+ Assert.Equal(2, connectionInfo.CountConnections);
+
+ // If we disconnect, we should remain in a consistent state to do it over again
+ // e.g. loop and do it over again
+ service.Disconnect(new DisconnectParams() { OwnerUri = connectionInfo.OwnerUri });
+
+ // We should be left with an empty connection map
+ Assert.Equal(0, service.OwnerToConnectionMap.Count);
+ }
+ }
+
+ [Fact]
+ public async Task DatabaseChangesAffectAllConnections()
+ {
+ // If we make a connection to a live database
+ ConnectionService service = ConnectionService.Instance;
+ var result = await TestObjects.InitLiveConnectionInfo();
+ ConnectionInfo connectionInfo = result.ConnectionInfo;
+ ConnectionDetails details = connectionInfo.ConnectionDetails;
+ string uri = connectionInfo.OwnerUri;
+ string initialDatabaseName = details.DatabaseName;
+ string newDatabaseName = "tempdb";
+ string changeDatabaseQuery = "use " + newDatabaseName;
+
+ // Then run any query to create a query DbConnection
+ var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
+ Query query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
+ query.Execute();
+ query.ExecutionTask.Wait();
+
+ // All open DbConnections (Query and Default) should have initialDatabaseName as their database
+ foreach (DbConnection connection in connectionInfo.AllConnections)
+ {
+ Assert.Equal(connection.Database, initialDatabaseName);
+ }
+
+ // If we run a query to change the database
+ query = new Query(changeDatabaseQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
+ query.Execute();
+ query.ExecutionTask.Wait();
+
+ // All open DbConnections (Query and Default) should have newDatabaseName as their database
+ foreach (DbConnection connection in connectionInfo.AllConnections)
+ {
+ Assert.Equal(connection.Database, newDatabaseName);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs
index bd74e613..6abd18d5 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs
@@ -685,19 +685,21 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
{
var result = await TestObjects.InitLiveConnectionInfo();
ConnectionInfo connInfo = result.ConnectionInfo;
+ DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default];
- Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connInfo.SqlConnection));
+
+ Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection));
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder();
Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder));
ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder());
- Assert.NotNull(ReliableConnectionHelper.GetServerName(connInfo.SqlConnection));
- Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connInfo.SqlConnection));
+ Assert.NotNull(ReliableConnectionHelper.GetServerName(connection));
+ Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connection));
- Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connInfo.SqlConnection));
+ Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connection));
- ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connInfo.SqlConnection);
+ ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connection);
Assert.NotNull(ReliableConnectionHelper.IsVersionGreaterThan2012RTM(info));
}
@@ -728,8 +730,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection
{
var result = await TestObjects.InitLiveConnectionInfo();
ConnectionInfo connInfo = result.ConnectionInfo;
+ DbConnection dbConnection;
+ connInfo.TryGetConnection(ConnectionType.Default, out dbConnection);
- var connection = connInfo.SqlConnection as ReliableSqlConnection;
+ var connection = dbConnection as ReliableSqlConnection;
var command = new ReliableSqlConnection.ReliableSqlCommand(connection);
Assert.NotNull(command.Connection);
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs
index 4bb145a6..45c38aa6 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs
@@ -19,8 +19,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution.DataSt
private async Task GetTestStorageDataReader(string query)
{
var result = await TestObjects.InitLiveConnectionInfo();
+ DbConnection connection;
+ result.ConnectionInfo.TryGetConnection(ConnectionType.Default, out connection);
- var command = result.ConnectionInfo.SqlConnection.CreateCommand();
+ var command = connection.CreateCommand();
command.CommandText = query;
DbDataReader reader = command.ExecuteReader();
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/ExecuteTests.cs
new file mode 100644
index 00000000..5b25a9d5
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/ExecuteTests.cs
@@ -0,0 +1,112 @@
+using System;
+using System.Collections.Generic;
+using System.Data.Common;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
+using Microsoft.SqlTools.ServiceLayer.SqlContext;
+using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution;
+using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
+using Microsoft.SqlTools.Test.Utility;
+using Xunit;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution
+{
+ public class ExecuteTests
+ {
+ [Fact]
+ public async Task RollbackTransactionFailsWithoutBeginTransaction()
+ {
+ const string refactorText = "ROLLBACK TRANSACTION";
+
+ // Given a connection to a live database
+ var result = await TestObjects.InitLiveConnectionInfo();
+ ConnectionInfo connInfo = result.ConnectionInfo;
+ var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
+
+ // If I run a "ROLLBACK TRANSACTION" query
+ Query query = new Query(refactorText, connInfo, new QueryExecutionSettings(), fileStreamFactory);
+ query.Execute();
+ query.ExecutionTask.Wait();
+
+ // There should be an error
+ Assert.True(query.Batches[0].HasError);
+ }
+
+ [Fact]
+ public async Task TransactionsSucceedAcrossQueries()
+ {
+ const string beginText = "BEGIN TRANSACTION";
+ const string rollbackText = "ROLLBACK TRANSACTION";
+
+ // Given a connection to a live database
+ var result = await TestObjects.InitLiveConnectionInfo();
+ ConnectionInfo connInfo = result.ConnectionInfo;
+ var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
+
+ // If I run a "BEGIN TRANSACTION" query
+ CreateAndExecuteQuery(beginText, connInfo, fileStreamFactory);
+
+ // Then I run a "ROLLBACK TRANSACTION" query, there should be no errors
+ Query rollbackQuery = CreateAndExecuteQuery(rollbackText, connInfo, fileStreamFactory);
+ Assert.False(rollbackQuery.Batches[0].HasError);
+ }
+
+ [Fact]
+ public async Task TempTablesPersistAcrossQueries()
+ {
+ const string createTempText = "CREATE TABLE #someTempTable (id int)";
+ const string insertTempText = "INSERT INTO #someTempTable VALUES(1)";
+
+ // Given a connection to a live database
+ var result = await TestObjects.InitLiveConnectionInfo();
+ ConnectionInfo connInfo = result.ConnectionInfo;
+ var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
+
+ // If I run a query creating a temp table
+ CreateAndExecuteQuery(createTempText, connInfo, fileStreamFactory);
+
+ // Then I run a different query using that temp table, there should be no errors
+ Query insertTempQuery = CreateAndExecuteQuery(insertTempText, connInfo, fileStreamFactory);
+ Assert.False(insertTempQuery.Batches[0].HasError);
+ }
+
+ [Fact]
+ public async Task DatabaseChangesWhenCallingUseDatabase()
+ {
+ const string master = "master";
+ const string tempdb = "tempdb";
+ const string useQuery = "USE {0}";
+
+ // Given a connection to a live database
+ var result = await TestObjects.InitLiveConnectionInfo();
+ ConnectionInfo connInfo = result.ConnectionInfo;
+ DbConnection connection;
+ connInfo.TryGetConnection(ConnectionType.Default, out connection);
+
+ var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
+
+ // If I use master, the current database should be master
+ CreateAndExecuteQuery(string.Format(useQuery, master), connInfo, fileStreamFactory);
+ Assert.Equal(master, connection.Database);
+
+ // If I use tempdb, the current database should be tempdb
+ CreateAndExecuteQuery(string.Format(useQuery, tempdb), connInfo, fileStreamFactory);
+ Assert.Equal(tempdb, connection.Database);
+
+ // If I switch back to master, the current database should be master
+ CreateAndExecuteQuery(string.Format(useQuery, master), connInfo, fileStreamFactory);
+ Assert.Equal(master, connection.Database);
+ }
+
+ public Query CreateAndExecuteQuery(string queryText, ConnectionInfo connectionInfo, IFileStreamFactory fileStreamFactory)
+ {
+ Query query = new Query(queryText, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
+ query.Execute();
+ query.ExecutionTask.Wait();
+ return query;
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs
index 337274a9..846fe197 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs
@@ -19,6 +19,10 @@ using Microsoft.SqlTools.Test.Utility;
using Moq;
using Moq.Protected;
using Xunit;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution;
+using Microsoft.SqlTools.ServiceLayer.SqlContext;
+using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution;
+using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
{
@@ -994,5 +998,161 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
Assert.NotNull(errorMessage);
Assert.NotEmpty(errorMessage);
}
+
+ [Fact]
+ public async void ConnectingTwiceWithTheSameUriDoesNotCreateAnotherDbConnection()
+ {
+ // Setup the connect and disconnect params
+ var connectParamsSame1 = new ConnectParams()
+ {
+ OwnerUri = "connectParamsSame",
+ Connection = TestObjects.GetTestConnectionDetails()
+ };
+ var connectParamsSame2 = new ConnectParams()
+ {
+ OwnerUri = "connectParamsSame",
+ Connection = TestObjects.GetTestConnectionDetails()
+ };
+ var disconnectParamsSame = new DisconnectParams()
+ {
+ OwnerUri = connectParamsSame1.OwnerUri
+ };
+ var connectParamsDifferent = new ConnectParams()
+ {
+ OwnerUri = "connectParamsDifferent",
+ Connection = TestObjects.GetTestConnectionDetails()
+ };
+ var disconnectParamsDifferent = new DisconnectParams()
+ {
+ OwnerUri = connectParamsDifferent.OwnerUri
+ };
+
+ // Given a request to connect to a database, there should be no initial connections in the map
+ var service = TestObjects.GetTestConnectionService();
+ Dictionary ownerToConnectionMap = service.OwnerToConnectionMap;
+ Assert.Equal(0, ownerToConnectionMap.Count);
+
+ // If we connect to the service, there should be 1 connection
+ await service.Connect(connectParamsSame1);
+ Assert.Equal(1, ownerToConnectionMap.Count);
+
+ // If we connect again with the same URI, there should still be 1 connection
+ await service.Connect(connectParamsSame2);
+ Assert.Equal(1, ownerToConnectionMap.Count);
+
+ // If we connect with a different URI, there should be 2 connections
+ await service.Connect(connectParamsDifferent);
+ Assert.Equal(2, ownerToConnectionMap.Count);
+
+ // If we disconenct with the unique URI, there should be 1 connection
+ service.Disconnect(disconnectParamsDifferent);
+ Assert.Equal(1, ownerToConnectionMap.Count);
+
+ // If we disconenct with the duplicate URI, there should be 0 connections
+ service.Disconnect(disconnectParamsSame);
+ Assert.Equal(0, ownerToConnectionMap.Count);
+ }
+
+ [Fact]
+ public async void DbConnectionDoesntLeakUponDisconnect()
+ {
+ // If we connect with a single URI and 2 connection types
+ var connectParamsDefault = new ConnectParams()
+ {
+ OwnerUri = "connectParams",
+ Connection = TestObjects.GetTestConnectionDetails(),
+ Type = ConnectionType.Default
+ };
+ var connectParamsQuery = new ConnectParams()
+ {
+ OwnerUri = "connectParams",
+ Connection = TestObjects.GetTestConnectionDetails(),
+ Type = ConnectionType.Query
+ };
+ var disconnectParams = new DisconnectParams()
+ {
+ OwnerUri = connectParamsDefault.OwnerUri
+ };
+ var service = TestObjects.GetTestConnectionService();
+ await service.Connect(connectParamsDefault);
+ await service.Connect(connectParamsQuery);
+
+ // We should have one ConnectionInfo and 2 DbConnections
+ ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri];
+ Assert.Equal(2, connectionInfo.CountConnections);
+ Assert.Equal(1, service.OwnerToConnectionMap.Count);
+
+ // If we record when the Default connecton calls Close()
+ bool defaultDisconnectCalled = false;
+ var mockDefaultConnection = new Mock { CallBase = true };
+ mockDefaultConnection.Setup(x => x.Close())
+ .Callback(() =>
+ {
+ defaultDisconnectCalled = true;
+ });
+ connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Default] = mockDefaultConnection.Object;
+
+ // And when the Query connecton calls Close()
+ bool queryDisconnectCalled = false;
+ var mockQueryConnection = new Mock { CallBase = true };
+ mockQueryConnection.Setup(x => x.Close())
+ .Callback(() =>
+ {
+ queryDisconnectCalled = true;
+ });
+ connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Query] = mockQueryConnection.Object;
+
+ // If we disconnect all open connections with the same URI as used above
+ service.Disconnect(disconnectParams);
+
+ // Close() should have gotten called for both DbConnections
+ Assert.True(defaultDisconnectCalled);
+ Assert.True(queryDisconnectCalled);
+
+ // And the maps that hold connection data should be empty
+ Assert.Equal(0, connectionInfo.CountConnections);
+ Assert.Equal(0, service.OwnerToConnectionMap.Count);
+ }
+
+ [Fact]
+ public async void ClosingQueryConnectionShouldLeaveDefaultConnectionOpen()
+ {
+ // Setup the connect and disconnect params
+ var connectParamsDefault = new ConnectParams()
+ {
+ OwnerUri = "connectParamsSame",
+ Connection = TestObjects.GetTestConnectionDetails(),
+ Type = ConnectionType.Default
+ };
+ var connectParamsQuery = new ConnectParams()
+ {
+ OwnerUri = connectParamsDefault.OwnerUri,
+ Connection = TestObjects.GetTestConnectionDetails(),
+ Type = ConnectionType.Query
+ };
+ var disconnectParamsQuery = new DisconnectParams()
+ {
+ OwnerUri = connectParamsDefault.OwnerUri,
+ Type = connectParamsQuery.Type
+ };
+
+ // If I connect a Default and a Query connection
+ var service = TestObjects.GetTestConnectionService();
+ Dictionary ownerToConnectionMap = service.OwnerToConnectionMap;
+ await service.Connect(connectParamsDefault);
+ await service.Connect(connectParamsQuery);
+ ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri];
+
+ // There should be 2 connections in the map
+ Assert.Equal(2, connectionInfo.CountConnections);
+
+ // If I Disconnect only the Query connection, there should be 1 connection in the map
+ service.Disconnect(disconnectParamsQuery);
+ Assert.Equal(1, connectionInfo.CountConnections);
+
+ // If I reconnect, there should be 2 again
+ await service.Connect(connectParamsQuery);
+ Assert.Equal(2, connectionInfo.CountConnections);
+ }
}
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs
index 9347857d..fd512b68 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs
@@ -82,6 +82,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
public static Query GetBasicExecutedQuery()
{
ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false);
+
+ // Query won't be able to request a new query DbConnection unless the ConnectionService has a
+ // ConnectionInfo with the same URI as the query, so we will manually set it
+ ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci;
+
Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory(new Dictionary()));
query.Execute();
query.ExecutionTask.Wait();
@@ -222,6 +227,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, StandardConnectionDetails);
}
+ public static ConnectionInfo CreateConnectedConnectionInfo(Dictionary[][] data, bool throwOnRead, string type = ConnectionType.Default)
+ {
+ ConnectionService connectionService = ConnectionService.Instance;
+ connectionService.OwnerToConnectionMap.Clear();
+ connectionService.ConnectionFactory = CreateMockFactory(data, throwOnRead);
+
+ ConnectParams connectParams = new ConnectParams()
+ {
+ Connection = StandardConnectionDetails,
+ OwnerUri = Common.OwnerUri,
+ Type = type
+ };
+
+ connectionService.Connect(connectParams).Wait();
+ return connectionService.OwnerToConnectionMap[OwnerUri];
+ }
+
#endregion
#region Service Mocking
@@ -233,12 +255,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// Create a place for the temp "files" to be written
storage = new Dictionary();
- // Create the connection factory with the dataset
- var factory = CreateTestConnectionInfo(data, throwOnRead).Factory;
-
// Mock the connection service
var connectionService = new Mock();
- ConnectionInfo ci = new ConnectionInfo(factory, OwnerUri, StandardConnectionDetails);
+ ConnectionInfo ci = CreateConnectedConnectionInfo(data, throwOnRead);
ConnectionInfo outValMock;
connectionService
.Setup(service => service.TryFindConnection(It.IsAny(), out outValMock))
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
index c61237fa..fc1c8fde 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
@@ -7,10 +7,16 @@
using System.Data.Common;
using System.Threading.Tasks;
+using System;
+using System.Collections.Generic;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
+using Microsoft.SqlTools.ServiceLayer.SqlContext;
+using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
+using Microsoft.SqlTools.Test.Utility;
+using Xunit;
namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
{
@@ -24,12 +30,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// If:
// ... I create a query with a udt column in the result set
ConnectionInfo connectionInfo = TestObjects.GetTestConnectionInfo();
- Query query = new Query(Common.UdtQuery, connectionInfo, new QueryExecutionSettings(), Common.GetFileStreamFactory());
+ Query query = new Query(Common.UdtQuery, connectionInfo, new QueryExecutionSettings(), Common.GetFileStreamFactory(new Dictionary()));
// If:
// ... I then execute the query
DateTime startTime = DateTime.Now;
- query.Execute().Wait();
+ query.Execute();
+ query.ExecutionTask.Wait();
// Then:
// ... The query should complete within 2 seconds since retry logic should not kick in
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs
index d2a67ede..fad77f8f 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs
@@ -163,7 +163,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution
// If:
// ... I create a query from two batches (with separator)
- ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
+ ConnectionInfo ci = Common.CreateConnectedConnectionInfo(null, false);
+
string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery);
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
Query query = new Query(queryText, ci, new QueryExecutionSettings(), fileStreamFactory);
@@ -280,6 +281,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution
// If:
// ... I create a query from an invalid batch
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true);
+ ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci;
+
var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary());
Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), fileStreamFactory);
BatchCallbackHelper(query,