diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/CancelTokenKey.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/CancelTokenKey.cs new file mode 100644 index 00000000..9976916b --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/CancelTokenKey.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Used to uniquely identify a CancellationTokenSource associated with both + /// a string URI and a string connection type. + /// + public class CancelTokenKey : CancelConnectParams, IEquatable + { + public override bool Equals(object obj) + { + CancelTokenKey other = obj as CancelTokenKey; + if (other == null) + { + return false; + } + + return other.OwnerUri == OwnerUri && other.Type == Type; + } + + public bool Equals(CancelTokenKey obj) + { + return obj.OwnerUri == OwnerUri && obj.Type == Type; + } + + public override int GetHashCode() + { + return OwnerUri.GetHashCode() ^ Type.GetHashCode(); + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs index 5d140b9a..168af4a7 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -4,8 +4,12 @@ // using System; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.Data.Common; +using System.Linq; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.Connection { @@ -23,7 +27,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection OwnerUri = ownerUri; ConnectionDetails = details; ConnectionId = Guid.NewGuid(); - IntellisenseMetrics = new InteractionMetrics(new int[] { 50, 100, 200, 500, 1000, 2000 }); + IntellisenseMetrics = new InteractionMetrics(new int[] {50, 100, 200, 500, 1000, 2000}); } /// @@ -39,17 +43,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// /// Factory used for creating the SQL connection associated with the connection info. /// - public ISqlConnectionFactory Factory {get; private set;} + public ISqlConnectionFactory Factory { get; private set; } /// /// Properties used for creating/opening the SQL connection. /// public ConnectionDetails ConnectionDetails { get; private set; } - + /// - /// The connection to the SQL database that commands will be run against. + /// A map containing all connections to the database that are associated with + /// this ConnectionInfo's OwnerUri. + /// This is internal for testing access only /// - public DbConnection SqlConnection { get; set; } + internal readonly ConcurrentDictionary ConnectionTypeToConnectionMap = + new ConcurrentDictionary(); /// /// Intellisense Metrics @@ -60,8 +67,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// Returns true is the db connection is to a SQL db /// public bool IsAzure { get; set; } - - /// + /// Returns true if the sql connection is to a DW instance /// public bool IsSqlDW { get; set; } @@ -71,5 +77,86 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// public int MajorVersion { get; set; } + /// + /// All DbConnection instances held by this ConnectionInfo + /// + public ICollection AllConnections + { + get + { + return ConnectionTypeToConnectionMap.Values; + } + } + + /// + /// All connection type strings held by this ConnectionInfo + /// + /// + public ICollection AllConnectionTypes + { + get + { + return ConnectionTypeToConnectionMap.Keys; + } + } + + /// + /// The count of DbConnectioninstances held by this ConnectionInfo + /// + public int CountConnections + { + get + { + return ConnectionTypeToConnectionMap.Count; + } + } + + /// + /// Try to get the DbConnection associated with the given connection type string. + /// + /// true if a connection with type connectionType was located and out connection was set, + /// false otherwise + /// Thrown when connectionType is null or empty + public bool TryGetConnection(string connectionType, out DbConnection connection) + { + Validate.IsNotNullOrEmptyString("Connection Type", connectionType); + return ConnectionTypeToConnectionMap.TryGetValue(connectionType, out connection); + } + + /// + /// Adds a DbConnection to this object and associates it with the given + /// connection type string. If a connection already exists with an identical + /// connection type string, it is not overwritten. Ignores calls where connectionType = null + /// + /// Thrown when connectionType is null or empty + public void AddConnection(string connectionType, DbConnection connection) + { + Validate.IsNotNullOrEmptyString("Connection Type", connectionType); + ConnectionTypeToConnectionMap.TryAdd(connectionType, connection); + } + + /// + /// Removes the single DbConnection instance associated with string connectionType + /// + /// Thrown when connectionType is null or empty + public void RemoveConnection(string connectionType) + { + Validate.IsNotNullOrEmptyString("Connection Type", connectionType); + DbConnection connection; + ConnectionTypeToConnectionMap.TryRemove(connectionType, out connection); + } + + /// + /// Removes all DbConnection instances held by this object + /// + public void RemoveAllConnections() + { + foreach (var type in AllConnectionTypes) + { + DbConnection connection; + ConnectionTypeToConnectionMap.TryRemove(type, out connection); + } + } + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index f9092c70..2ca48d8c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -52,7 +52,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection private readonly Dictionary ownerToConnectionMap = new Dictionary(); - private readonly ConcurrentDictionary ownerToCancellationTokenSourceMap = new ConcurrentDictionary(); + /// + /// A map containing all CancellationTokenSource objects that are associated with a given URI/ConnectionType pair. + /// Entries in this map correspond to DbConnection instances that are in the process of connecting. + /// + private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap = + new ConcurrentDictionary(); private readonly object cancellationTokenSourceLock = new object(); @@ -120,6 +125,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } return this.connectionFactory; } + + internal set { this.connectionFactory = value; } } /// @@ -138,12 +145,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } /// - /// Open a connection with the specified connection details + /// Validates the given ConnectParams object. /// - /// - public async Task Connect(ConnectParams connectionParams) + /// The params to validate + /// A ConnectionCompleteParams object upon validation error, + /// null upon validation success + public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams) { - // Validate parameters string paramValidationErrorMessage; if (connectionParams == null) { @@ -161,29 +169,139 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection }; } - // Resolve if it is an existing connection - // Disconnect active connection if the URI is already connected + // return null upon success + return null; + } + + /// + /// Open a connection with the specified ConnectParams + /// + public async Task Connect(ConnectParams connectionParams) + { + // Validate parameters + ConnectionCompleteParams validationResults = ValidateConnectParams(connectionParams); + if (validationResults != null) + { + return validationResults; + } + + // If there is no ConnectionInfo in the map, create a new ConnectionInfo, + // but wait until later when we are connected to add it to the map. ConnectionInfo connectionInfo; - if (ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo) ) + if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo)) + { + connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); + } + + // Resolve if it is an existing connection + // Disconnect active connection if the URI is already connected for this connection type + DbConnection existingConnection; + if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection)) { var disconnectParams = new DisconnectParams() { - OwnerUri = connectionParams.OwnerUri + OwnerUri = connectionParams.OwnerUri, + Type = connectionParams.Type }; Disconnect(disconnectParams); } - connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); - // try to connect - var response = new ConnectionCompleteParams {OwnerUri = connectionParams.OwnerUri}; + // Try to open a connection with the given ConnectParams + ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams); + if (response != null) + { + return response; + } + + // If this is the first connection for this URI, add the ConnectionInfo to the map + if (!ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri)) + { + ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; + } + + // Invoke callback notifications + InvokeOnConnectionActivities(connectionInfo, connectionParams); + + // Return information about the connected SQL Server instance + return GetConnectionCompleteParams(connectionParams.Type, connectionInfo); + } + + /// + /// Creates a ConnectionCompleteParams as a response to a successful connection. + /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo. + /// + /// A ConnectionCompleteParams in response to the successful connection + private ConnectionCompleteParams GetConnectionCompleteParams(string connectionType, ConnectionInfo connectionInfo) + { + ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionType }; + + try + { + DbConnection connection; + connectionInfo.TryGetConnection(connectionType, out connection); + + // Update with the actual database name in connectionInfo and result + // Doing this here as we know the connection is open - expect to do this only on connecting + connectionInfo.ConnectionDetails.DatabaseName = connection.Database; + response.ConnectionSummary = new ConnectionSummary + { + ServerName = connectionInfo.ConnectionDetails.ServerName, + DatabaseName = connectionInfo.ConnectionDetails.DatabaseName, + UserName = connectionInfo.ConnectionDetails.UserName, + }; + + response.ConnectionId = connectionInfo.ConnectionId.ToString(); + + var reliableConnection = connection as ReliableSqlConnection; + DbConnection underlyingConnection = reliableConnection != null + ? reliableConnection.GetUnderlyingConnection() + : connection; + + ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(underlyingConnection); + response.ServerInfo = new ServerInfo + { + ServerMajorVersion = serverInfo.ServerMajorVersion, + ServerMinorVersion = serverInfo.ServerMinorVersion, + ServerReleaseVersion = serverInfo.ServerReleaseVersion, + EngineEditionId = serverInfo.EngineEditionId, + ServerVersion = serverInfo.ServerVersion, + ServerLevel = serverInfo.ServerLevel, + ServerEdition = serverInfo.ServerEdition, + IsCloud = serverInfo.IsCloud, + AzureVersion = serverInfo.AzureVersion, + OsVersion = serverInfo.OsVersion + }; + connectionInfo.IsAzure = serverInfo.IsCloud; + connectionInfo.MajorVersion = serverInfo.ServerMajorVersion; + connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse); + } + catch (Exception ex) + { + response.Messages = ex.ToString(); + } + + return response; + } + + /// + /// Tries to create and open a connection with the given ConnectParams. + /// + /// null upon success, a ConnectionCompleteParams detailing the error upon failure + private async Task TryOpenConnection(ConnectionInfo connectionInfo, ConnectParams connectionParams) + { CancellationTokenSource source = null; + DbConnection connection = null; + CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type }; + ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type }; + try { // build the connection string from the input parameters string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); // create a sql connection instance - connectionInfo.SqlConnection = connectionInfo.Factory.CreateSqlConnection(connectionString); + connection = connectionInfo.Factory.CreateSqlConnection(connectionString); + connectionInfo.AddConnection(connectionParams.Type, connection); // Add a cancellation token source so that the connection OpenAsync() can be cancelled using (source = new CancellationTokenSource()) @@ -193,11 +311,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { // If the URI is currently connecting from a different request, cancel it before we try to connect CancellationTokenSource currentSource; - if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out currentSource)) + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource)) { currentSource.Cancel(); } - ownerToCancellationTokenSourceMap[connectionParams.OwnerUri] = source; + cancelTupleToCancellationTokenSourceMap[cancelKey] = source; } // Create a task to handle cancellation requests @@ -214,10 +332,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } }); - var openTask = Task.Run(async () => { - await connectionInfo.SqlConnection.OpenAsync(source.Token); + var openTask = Task.Run(async () => + { + await connection.OpenAsync(source.Token); }); - + // Open the connection await Task.WhenAny(openTask, cancellationTask).Unwrap(); source.Cancel(); @@ -250,60 +369,61 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { // Only remove the token from the map if it is the same one created by this request CancellationTokenSource sourceValue; - if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out sourceValue) && sourceValue == source) + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out sourceValue) && sourceValue == source) { - ownerToCancellationTokenSourceMap.TryRemove(connectionParams.OwnerUri, out sourceValue); + cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue); } } } - ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; + // Return null upon success + return null; + } - // Update with the actual database name in connectionInfo and result - // Doing this here as we know the connection is open - expect to do this only on connecting - connectionInfo.ConnectionDetails.DatabaseName = connectionInfo.SqlConnection.Database; - response.ConnectionSummary = new ConnectionSummary + /// + /// Gets the existing connection with the given URI and connection type string. If none exists, + /// creates a new connection. This cannot be used to create a default connection or to create a + /// connection if a default connection does not exist. + /// + public async Task GetOrOpenConnection(string ownerUri, string connectionType) + { + if (string.IsNullOrEmpty(ownerUri) || string.IsNullOrEmpty(connectionType)) { - ServerName = connectionInfo.ConnectionDetails.ServerName, - DatabaseName = connectionInfo.ConnectionDetails.DatabaseName, - UserName = connectionInfo.ConnectionDetails.UserName, - }; + return null; + } - // invoke callback notifications - InvokeOnConnectionActivities(connectionInfo); - - // try to get information about the connected SQL Server instance - try + // Try to get the ConnectionInfo, if it exists + ConnectionInfo connectionInfo = ownerToConnectionMap[ownerUri]; + if (connectionInfo == null) { - var reliableConnection = connectionInfo.SqlConnection as ReliableSqlConnection; - DbConnection connection = reliableConnection != null ? reliableConnection.GetUnderlyingConnection() : connectionInfo.SqlConnection; - - ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(connection); - response.ServerInfo = new ServerInfo + return null; + } + + // Make sure a default connection exists + DbConnection defaultConnection; + if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection)) + { + return null; + } + + // Try to get the DbConnection + DbConnection connection; + if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType) + { + // If the DbConnection does not exist and is not the default connection, create one. + // We can't create the default (initial) connection here because we won't have a ConnectionDetails + // if Connect() has not yet been called. + ConnectParams connectParams = new ConnectParams() { - ServerMajorVersion = serverInfo.ServerMajorVersion, - ServerMinorVersion = serverInfo.ServerMinorVersion, - ServerReleaseVersion = serverInfo.ServerReleaseVersion, - EngineEditionId = serverInfo.EngineEditionId, - ServerVersion = serverInfo.ServerVersion, - ServerLevel = serverInfo.ServerLevel, - ServerEdition = serverInfo.ServerEdition, - IsCloud = serverInfo.IsCloud, - AzureVersion = serverInfo.AzureVersion, - OsVersion = serverInfo.OsVersion + OwnerUri = ownerUri, + Connection = connectionInfo.ConnectionDetails, + Type = connectionType }; - connectionInfo.IsAzure = serverInfo.IsCloud; - connectionInfo.MajorVersion = serverInfo.ServerMajorVersion; - connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse); - } - catch(Exception ex) - { - response.Messages = ex.ToString(); + await Connect(connectParams); + connectionInfo.TryGetConnection(connectionType, out connection); } - // return the connection result - response.ConnectionId = connectionInfo.ConnectionId.ToString(); - return response; + return connection; } /// @@ -317,9 +437,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return false; } + CancelTokenKey cancelKey = new CancelTokenKey + { + OwnerUri = cancelParams.OwnerUri, + Type = cancelParams.Type + }; + // Cancel any current connection attempts for this URI CancellationTokenSource source; - if (ownerToCancellationTokenSourceMap.TryGetValue(cancelParams.OwnerUri, out source)) + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out source)) { try { @@ -331,10 +457,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return false; } } - else - { - return false; - } + + return false; } /// @@ -349,58 +473,128 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } // Cancel if we are in the middle of connecting - if (CancelConnect(new CancelConnectParams() { OwnerUri = disconnectParams.OwnerUri })) + if (CancelConnections(disconnectParams.OwnerUri, disconnectParams.Type)) { return false; } - // Lookup the connection owned by the URI + // Lookup the ConnectionInfo owned by the URI ConnectionInfo info; if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info)) { return false; } - if (ServiceHost != null) + // Call Close() on the connections we want to disconnect + // If no connections were located, return false + if (!CloseConnections(info, disconnectParams.Type)) + { + return false; + } + + // Remove the disconnected connections from the ConnectionInfo map + if (disconnectParams.Type == null) + { + info.RemoveAllConnections(); + } + else + { + info.RemoveConnection(disconnectParams.Type); + } + + // If the ConnectionInfo has no more connections, remove the ConnectionInfo + if (info.CountConnections == 0) + { + ownerToConnectionMap.Remove(disconnectParams.OwnerUri); + } + + // Handle Telemetry disconnect events if we are disconnecting the default connection + if (disconnectParams.Type == null || disconnectParams.Type == ConnectionType.Default) + { + HandleDisconnectTelemetry(info); + InvokeOnDisconnectionActivities(info); + } + + // Return true upon success + return true; + } + + /// + /// Cancel connections associated with the given ownerUri. + /// If connectionType is not null, cancel the connection with the given connectionType + /// If connectionType is null, cancel all pending connections associated with ownerUri. + /// + /// true if a single pending connection associated with the non-null connectionType was + /// found and cancelled, false otherwise + private bool CancelConnections(string ownerUri, string connectionType) + { + // Cancel the connection of the given type + if (connectionType != null) + { + // If we are trying to disconnect a specific connection and it was just cancelled, + // this will return true + return CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = connectionType }); + } + + // Cancel all pending connections + foreach (var entry in cancelTupleToCancellationTokenSourceMap) + { + string entryConnectionUri = entry.Key.OwnerUri; + string entryConnectionType = entry.Key.Type; + if (ownerUri.Equals(entryConnectionUri)) + { + CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = entryConnectionType }); + } + } + + return false; + } + + /// + /// Closes DbConnections associated with the given ConnectionInfo. + /// If connectionType is not null, closes the DbConnection with the type given by connectionType. + /// If connectionType is null, closes all DbConnections. + /// + /// true if connections were found and attempted to be closed, + /// false if no connections were found + private bool CloseConnections(ConnectionInfo connectionInfo, string connectionType) + { + ICollection connectionsToDisconnect = new List(); + if (connectionType == null) + { + connectionsToDisconnect = connectionInfo.AllConnections; + } + else + { + // Make sure there is an existing connection of this type + DbConnection connection; + if (!connectionInfo.TryGetConnection(connectionType, out connection)) + { + return false; + } + connectionsToDisconnect.Add(connection); + } + + if (connectionsToDisconnect.Count == 0) + { + return false; + } + + foreach (DbConnection connection in connectionsToDisconnect) { try { - // Send a telemetry notification for intellisense performance metrics - ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams() - { - Params = new TelemetryProperties - { - Properties = new Dictionary - { - { "IsAzure", info.IsAzure ? "1" : "0" } - }, - EventName = TelemetryEventNames.IntellisenseQuantile, - Measures = info.IntellisenseMetrics.Quantile - } - }); + connection.Close(); } - catch (Exception ex) + catch (Exception) { - Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString()); + // Ignore } } - // Close the connection - info.SqlConnection.Close(); - - // Remove URI mapping - ownerToConnectionMap.Remove(disconnectParams.OwnerUri); - - // Invoke callback notifications - foreach (var activity in this.onDisconnectActivities) - { - activity(info.ConnectionDetails, disconnectParams.OwnerUri); - } - - // Success return true; } - + /// /// List all databases on the server specified /// @@ -486,11 +680,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { onDisconnectActivities.Add(activity); } - + /// /// Handle new connection requests /// - /// + /// /// /// protected async Task HandleConnectRequest( @@ -517,8 +711,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { try { + // result is null if the ConnectParams was successfully validated + ConnectionCompleteParams result = ValidateConnectParams(connectParams); + if (result != null) + { + await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); + return; + } + // open connection based on request details - ConnectionCompleteParams result = await Instance.Connect(connectParams); + result = await Instance.Connect(connectParams); await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); } catch (Exception ex) @@ -744,10 +946,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { try { - if (info.SqlConnection.State == ConnectionState.Open) + foreach (DbConnection connection in info.AllConnections) { - info.SqlConnection.ChangeDatabase(newDatabaseName); + if (connection.State == ConnectionState.Open) + { + connection.ChangeDatabase(newDatabaseName); + } } + info.ConnectionDetails.DatabaseName = newDatabaseName; // Fire a connection changed event @@ -771,13 +977,65 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } - private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo) + /// + /// Invokes the initial on-connect activities if the provided ConnectParams represents the default + /// connection. + /// + private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams) { + if (connectParams.Type != ConnectionType.Default) + { + return; + } + foreach (var activity in this.onConnectionActivities) { // not awaiting here to allow handlers to run in the background activity(connectionInfo); } } + + /// + /// Invokes the fianl on-disconnect activities if the provided DisconnectParams represents the default + /// connection or is null - representing that all connections are being disconnected. + /// + private void InvokeOnDisconnectionActivities(ConnectionInfo connectionInfo) + { + foreach (var activity in this.onDisconnectActivities) + { + activity(connectionInfo.ConnectionDetails, connectionInfo.OwnerUri); + } + } + + /// + /// Handles the Telemetry events that occur upon disconnect. + /// + /// + private void HandleDisconnectTelemetry(ConnectionInfo connectionInfo) + { + if (ServiceHost != null) + { + try + { + // Send a telemetry notification for intellisense performance metrics + ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams() + { + Params = new TelemetryProperties + { + Properties = new Dictionary + { + {"IsAzure", connectionInfo.IsAzure ? "1" : "0"} + }, + EventName = TelemetryEventNames.IntellisenseQuantile, + Measures = connectionInfo.IntellisenseMetrics.Quantile + } + }); + } + catch (Exception ex) + { + Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString()); + } + } + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs new file mode 100644 index 00000000..aabbb7d9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionType.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// String constants that represent connection types. + /// + /// Default: Connection used by the editor. Opened by the editor upon the initial connection. + /// Query: Connection used for executing queries. Opened when the first query is executed. + /// + public static class ConnectionType + { + public const string Default = "Default"; + public const string Query = "Query"; + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs index 9f2efdb0..9c76e590 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs @@ -14,6 +14,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace /// or a virtual file representing an object in a database. /// - public string OwnerUri { get; set; } + public string OwnerUri { get; set; } + + /// + /// The type of connection we are trying to cancel + /// + public string Type { get; set; } = ConnectionType.Default; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs index 31dad8c5..58d38ee4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs @@ -22,5 +22,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// connection properties to the same database. /// public ConnectionDetails Connection { get; set; } + + /// + /// The type of this connection. By default, this is set to ConnectionType.Default. + /// + public string Type { get; set; } = ConnectionType.Default; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs index 50517a52..0203a110 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs @@ -47,6 +47,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// Gets or sets the actual Connection established, including Database Name /// public ConnectionSummary ConnectionSummary { get; set; } + + /// + /// The type of connection that this notification is for + /// + public string Type { get; set; } = ConnectionType.Default; } /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs index 91bc7faf..d7e5de72 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/DisconnectParams.cs @@ -15,5 +15,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// or a virtual file representing an object in a database. /// public string OwnerUri { get; set; } + + /// + /// The type of connection we are disconnecting. If null, we will disconnect all connections. + /// connections. + /// + public string Type { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index 48b09b20..f9d2ad10 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -15,6 +15,7 @@ using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using System.Collections.Generic; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution @@ -369,82 +370,62 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - // Open up a connection for querying the database - string connectionString = ConnectionService.BuildConnectionString(editorConnection.ConnectionDetails); - // TODO: Don't create a new connection every time, see TFS #834978 - using (DbConnection conn = editorConnection.Factory.CreateSqlConnection(connectionString)) + // Locate and setup the connection + DbConnection queryConnection = await ConnectionService.Instance.GetOrOpenConnection(editorConnection.OwnerUri, ConnectionType.Query); + ReliableSqlConnection sqlConn = queryConnection as ReliableSqlConnection; + if (sqlConn != null) { - try + // Subscribe to database informational messages + sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage; + } + + try + { + // Execute beforeBatches synchronously, before the user defined batches + foreach (Batch b in BeforeBatches) { - await conn.OpenAsync(); - } - catch (Exception exception) - { - this.HasExecuted = true; - if (QueryConnectionException != null) - { - await QueryConnectionException(exception.Message); - } - return; + await b.Execute(queryConnection, cancellationSource.Token); } - ReliableSqlConnection sqlConn = conn as ReliableSqlConnection; + // We need these to execute synchronously, otherwise the user will be very unhappy + foreach (Batch b in Batches) + { + // Add completion callbacks + b.BatchStart += BatchStarted; + b.BatchCompletion += BatchCompleted; + b.BatchMessageSent += BatchMessageSent; + b.ResultSetCompletion += ResultSetCompleted; + await b.Execute(queryConnection, cancellationSource.Token); + } + + // Execute afterBatches synchronously, after the user defined batches + foreach (Batch b in AfterBatches) + { + await b.Execute(queryConnection, cancellationSource.Token); + } + + // Call the query execution callback + if (QueryCompleted != null) + { + await QueryCompleted(this); + } + } + catch (Exception) + { + // Call the query failure callback + if (QueryFailed != null) + { + await QueryFailed(this); + } + } + finally + { if (sqlConn != null) { // Subscribe to database informational messages - sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage; + sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage; } - - try - { - // Execute beforeBatches synchronously, before the user defined batches - foreach (Batch b in BeforeBatches) - { - await b.Execute(conn, cancellationSource.Token); - } - - // We need these to execute synchronously, otherwise the user will be very unhappy - foreach (Batch b in Batches) - { - // Add completion callbacks - b.BatchStart += BatchStarted; - b.BatchCompletion += BatchCompleted; - b.BatchMessageSent += BatchMessageSent; - b.ResultSetCompletion += ResultSetCompleted; - await b.Execute(conn, cancellationSource.Token); - } - - // Execute afterBatches synchronously, after the user defined batches - foreach (Batch b in AfterBatches) - { - await b.Execute(conn, cancellationSource.Token); - } - - // Call the query execution callback - if (QueryCompleted != null) - { - await QueryCompleted(this); - } - } - catch (Exception) - { - // Call the query failure callback - if (QueryFailed != null) - { - await QueryFailed(this); - } - } - finally - { - if (sqlConn != null) - { - // Subscribe to database informational messages - sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage; - } - } - - // TODO: Close connection after eliminating using statement for above TODO - } + } } /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ConnectionServiceTests.cs new file mode 100644 index 00000000..b78a2770 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ConnectionServiceTests.cs @@ -0,0 +1,104 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Data.Common; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.Test.Utility; +using Xunit; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection +{ + /// + /// Tests for the ServiceHost Connection Service tests that require a live database connection + /// + public class ConnectionServiceTests + { + [Fact] + public async Task RunningMultipleQueriesCreatesOnlyOneConnection() + { + // Connect/disconnect twice to ensure reconnection can occur + ConnectionService service = ConnectionService.Instance; + service.OwnerToConnectionMap.Clear(); + for (int i = 0; i < 2; i++) + { + var result = await TestObjects.InitLiveConnectionInfo(); + ConnectionInfo connectionInfo = result.ConnectionInfo; + string uri = connectionInfo.OwnerUri; + + // We should see one ConnectionInfo and one DbConnection + Assert.Equal(1, connectionInfo.CountConnections); + Assert.Equal(1, service.OwnerToConnectionMap.Count); + + // If we run a query + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory); + query.Execute(); + query.ExecutionTask.Wait(); + + // We should see two DbConnections + Assert.Equal(2, connectionInfo.CountConnections); + + // If we run another query + query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory); + query.Execute(); + query.ExecutionTask.Wait(); + + // We should still have 2 DbConnections + Assert.Equal(2, connectionInfo.CountConnections); + + // If we disconnect, we should remain in a consistent state to do it over again + // e.g. loop and do it over again + service.Disconnect(new DisconnectParams() { OwnerUri = connectionInfo.OwnerUri }); + + // We should be left with an empty connection map + Assert.Equal(0, service.OwnerToConnectionMap.Count); + } + } + + [Fact] + public async Task DatabaseChangesAffectAllConnections() + { + // If we make a connection to a live database + ConnectionService service = ConnectionService.Instance; + var result = await TestObjects.InitLiveConnectionInfo(); + ConnectionInfo connectionInfo = result.ConnectionInfo; + ConnectionDetails details = connectionInfo.ConnectionDetails; + string uri = connectionInfo.OwnerUri; + string initialDatabaseName = details.DatabaseName; + string newDatabaseName = "tempdb"; + string changeDatabaseQuery = "use " + newDatabaseName; + + // Then run any query to create a query DbConnection + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + Query query = new Query(Common.StandardQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory); + query.Execute(); + query.ExecutionTask.Wait(); + + // All open DbConnections (Query and Default) should have initialDatabaseName as their database + foreach (DbConnection connection in connectionInfo.AllConnections) + { + Assert.Equal(connection.Database, initialDatabaseName); + } + + // If we run a query to change the database + query = new Query(changeDatabaseQuery, connectionInfo, new QueryExecutionSettings(), fileStreamFactory); + query.Execute(); + query.ExecutionTask.Wait(); + + // All open DbConnections (Query and Default) should have newDatabaseName as their database + foreach (DbConnection connection in connectionInfo.AllConnections) + { + Assert.Equal(connection.Database, newDatabaseName); + } + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs index bd74e613..6abd18d5 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Connection/ReliableConnectionTests.cs @@ -685,19 +685,21 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection { var result = await TestObjects.InitLiveConnectionInfo(); ConnectionInfo connInfo = result.ConnectionInfo; + DbConnection connection = connInfo.ConnectionTypeToConnectionMap[ConnectionType.Default]; - Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connInfo.SqlConnection)); + + Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(connection)); SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(); Assert.True(ReliableConnectionHelper.IsAuthenticatingDatabaseMaster(builder)); ReliableConnectionHelper.TryAddAlwaysOnConnectionProperties(builder, new SqlConnectionStringBuilder()); - Assert.NotNull(ReliableConnectionHelper.GetServerName(connInfo.SqlConnection)); - Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connInfo.SqlConnection)); + Assert.NotNull(ReliableConnectionHelper.GetServerName(connection)); + Assert.NotNull(ReliableConnectionHelper.ReadServerVersion(connection)); - Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connInfo.SqlConnection)); + Assert.NotNull(ReliableConnectionHelper.GetAsSqlConnection(connection)); - ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connInfo.SqlConnection); + ReliableConnectionHelper.ServerInfo info = ReliableConnectionHelper.GetServerVersion(connection); Assert.NotNull(ReliableConnectionHelper.IsVersionGreaterThan2012RTM(info)); } @@ -728,8 +730,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Connection { var result = await TestObjects.InitLiveConnectionInfo(); ConnectionInfo connInfo = result.ConnectionInfo; + DbConnection dbConnection; + connInfo.TryGetConnection(ConnectionType.Default, out dbConnection); - var connection = connInfo.SqlConnection as ReliableSqlConnection; + var connection = dbConnection as ReliableSqlConnection; var command = new ReliableSqlConnection.ReliableSqlCommand(connection); Assert.NotNull(command.Connection); diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs index 4bb145a6..45c38aa6 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/DataStorage/StorageDataReaderTests.cs @@ -19,8 +19,10 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution.DataSt private async Task GetTestStorageDataReader(string query) { var result = await TestObjects.InitLiveConnectionInfo(); + DbConnection connection; + result.ConnectionInfo.TryGetConnection(ConnectionType.Default, out connection); - var command = result.ConnectionInfo.SqlConnection.CreateCommand(); + var command = connection.CreateCommand(); command.CommandText = query; DbDataReader reader = command.ExecuteReader(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/ExecuteTests.cs new file mode 100644 index 00000000..5b25a9d5 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/QueryExecution/ExecuteTests.cs @@ -0,0 +1,112 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.QueryExecution +{ + public class ExecuteTests + { + [Fact] + public async Task RollbackTransactionFailsWithoutBeginTransaction() + { + const string refactorText = "ROLLBACK TRANSACTION"; + + // Given a connection to a live database + var result = await TestObjects.InitLiveConnectionInfo(); + ConnectionInfo connInfo = result.ConnectionInfo; + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + + // If I run a "ROLLBACK TRANSACTION" query + Query query = new Query(refactorText, connInfo, new QueryExecutionSettings(), fileStreamFactory); + query.Execute(); + query.ExecutionTask.Wait(); + + // There should be an error + Assert.True(query.Batches[0].HasError); + } + + [Fact] + public async Task TransactionsSucceedAcrossQueries() + { + const string beginText = "BEGIN TRANSACTION"; + const string rollbackText = "ROLLBACK TRANSACTION"; + + // Given a connection to a live database + var result = await TestObjects.InitLiveConnectionInfo(); + ConnectionInfo connInfo = result.ConnectionInfo; + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + + // If I run a "BEGIN TRANSACTION" query + CreateAndExecuteQuery(beginText, connInfo, fileStreamFactory); + + // Then I run a "ROLLBACK TRANSACTION" query, there should be no errors + Query rollbackQuery = CreateAndExecuteQuery(rollbackText, connInfo, fileStreamFactory); + Assert.False(rollbackQuery.Batches[0].HasError); + } + + [Fact] + public async Task TempTablesPersistAcrossQueries() + { + const string createTempText = "CREATE TABLE #someTempTable (id int)"; + const string insertTempText = "INSERT INTO #someTempTable VALUES(1)"; + + // Given a connection to a live database + var result = await TestObjects.InitLiveConnectionInfo(); + ConnectionInfo connInfo = result.ConnectionInfo; + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + + // If I run a query creating a temp table + CreateAndExecuteQuery(createTempText, connInfo, fileStreamFactory); + + // Then I run a different query using that temp table, there should be no errors + Query insertTempQuery = CreateAndExecuteQuery(insertTempText, connInfo, fileStreamFactory); + Assert.False(insertTempQuery.Batches[0].HasError); + } + + [Fact] + public async Task DatabaseChangesWhenCallingUseDatabase() + { + const string master = "master"; + const string tempdb = "tempdb"; + const string useQuery = "USE {0}"; + + // Given a connection to a live database + var result = await TestObjects.InitLiveConnectionInfo(); + ConnectionInfo connInfo = result.ConnectionInfo; + DbConnection connection; + connInfo.TryGetConnection(ConnectionType.Default, out connection); + + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); + + // If I use master, the current database should be master + CreateAndExecuteQuery(string.Format(useQuery, master), connInfo, fileStreamFactory); + Assert.Equal(master, connection.Database); + + // If I use tempdb, the current database should be tempdb + CreateAndExecuteQuery(string.Format(useQuery, tempdb), connInfo, fileStreamFactory); + Assert.Equal(tempdb, connection.Database); + + // If I switch back to master, the current database should be master + CreateAndExecuteQuery(string.Format(useQuery, master), connInfo, fileStreamFactory); + Assert.Equal(master, connection.Database); + } + + public Query CreateAndExecuteQuery(string queryText, ConnectionInfo connectionInfo, IFileStreamFactory fileStreamFactory) + { + Query query = new Query(queryText, connectionInfo, new QueryExecutionSettings(), fileStreamFactory); + query.Execute(); + query.ExecutionTask.Wait(); + return query; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs index 337274a9..846fe197 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs @@ -19,6 +19,10 @@ using Microsoft.SqlTools.Test.Utility; using Moq; using Moq.Protected; using Xunit; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; namespace Microsoft.SqlTools.ServiceLayer.Test.Connection { @@ -994,5 +998,161 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection Assert.NotNull(errorMessage); Assert.NotEmpty(errorMessage); } + + [Fact] + public async void ConnectingTwiceWithTheSameUriDoesNotCreateAnotherDbConnection() + { + // Setup the connect and disconnect params + var connectParamsSame1 = new ConnectParams() + { + OwnerUri = "connectParamsSame", + Connection = TestObjects.GetTestConnectionDetails() + }; + var connectParamsSame2 = new ConnectParams() + { + OwnerUri = "connectParamsSame", + Connection = TestObjects.GetTestConnectionDetails() + }; + var disconnectParamsSame = new DisconnectParams() + { + OwnerUri = connectParamsSame1.OwnerUri + }; + var connectParamsDifferent = new ConnectParams() + { + OwnerUri = "connectParamsDifferent", + Connection = TestObjects.GetTestConnectionDetails() + }; + var disconnectParamsDifferent = new DisconnectParams() + { + OwnerUri = connectParamsDifferent.OwnerUri + }; + + // Given a request to connect to a database, there should be no initial connections in the map + var service = TestObjects.GetTestConnectionService(); + Dictionary ownerToConnectionMap = service.OwnerToConnectionMap; + Assert.Equal(0, ownerToConnectionMap.Count); + + // If we connect to the service, there should be 1 connection + await service.Connect(connectParamsSame1); + Assert.Equal(1, ownerToConnectionMap.Count); + + // If we connect again with the same URI, there should still be 1 connection + await service.Connect(connectParamsSame2); + Assert.Equal(1, ownerToConnectionMap.Count); + + // If we connect with a different URI, there should be 2 connections + await service.Connect(connectParamsDifferent); + Assert.Equal(2, ownerToConnectionMap.Count); + + // If we disconenct with the unique URI, there should be 1 connection + service.Disconnect(disconnectParamsDifferent); + Assert.Equal(1, ownerToConnectionMap.Count); + + // If we disconenct with the duplicate URI, there should be 0 connections + service.Disconnect(disconnectParamsSame); + Assert.Equal(0, ownerToConnectionMap.Count); + } + + [Fact] + public async void DbConnectionDoesntLeakUponDisconnect() + { + // If we connect with a single URI and 2 connection types + var connectParamsDefault = new ConnectParams() + { + OwnerUri = "connectParams", + Connection = TestObjects.GetTestConnectionDetails(), + Type = ConnectionType.Default + }; + var connectParamsQuery = new ConnectParams() + { + OwnerUri = "connectParams", + Connection = TestObjects.GetTestConnectionDetails(), + Type = ConnectionType.Query + }; + var disconnectParams = new DisconnectParams() + { + OwnerUri = connectParamsDefault.OwnerUri + }; + var service = TestObjects.GetTestConnectionService(); + await service.Connect(connectParamsDefault); + await service.Connect(connectParamsQuery); + + // We should have one ConnectionInfo and 2 DbConnections + ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri]; + Assert.Equal(2, connectionInfo.CountConnections); + Assert.Equal(1, service.OwnerToConnectionMap.Count); + + // If we record when the Default connecton calls Close() + bool defaultDisconnectCalled = false; + var mockDefaultConnection = new Mock { CallBase = true }; + mockDefaultConnection.Setup(x => x.Close()) + .Callback(() => + { + defaultDisconnectCalled = true; + }); + connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Default] = mockDefaultConnection.Object; + + // And when the Query connecton calls Close() + bool queryDisconnectCalled = false; + var mockQueryConnection = new Mock { CallBase = true }; + mockQueryConnection.Setup(x => x.Close()) + .Callback(() => + { + queryDisconnectCalled = true; + }); + connectionInfo.ConnectionTypeToConnectionMap[ConnectionType.Query] = mockQueryConnection.Object; + + // If we disconnect all open connections with the same URI as used above + service.Disconnect(disconnectParams); + + // Close() should have gotten called for both DbConnections + Assert.True(defaultDisconnectCalled); + Assert.True(queryDisconnectCalled); + + // And the maps that hold connection data should be empty + Assert.Equal(0, connectionInfo.CountConnections); + Assert.Equal(0, service.OwnerToConnectionMap.Count); + } + + [Fact] + public async void ClosingQueryConnectionShouldLeaveDefaultConnectionOpen() + { + // Setup the connect and disconnect params + var connectParamsDefault = new ConnectParams() + { + OwnerUri = "connectParamsSame", + Connection = TestObjects.GetTestConnectionDetails(), + Type = ConnectionType.Default + }; + var connectParamsQuery = new ConnectParams() + { + OwnerUri = connectParamsDefault.OwnerUri, + Connection = TestObjects.GetTestConnectionDetails(), + Type = ConnectionType.Query + }; + var disconnectParamsQuery = new DisconnectParams() + { + OwnerUri = connectParamsDefault.OwnerUri, + Type = connectParamsQuery.Type + }; + + // If I connect a Default and a Query connection + var service = TestObjects.GetTestConnectionService(); + Dictionary ownerToConnectionMap = service.OwnerToConnectionMap; + await service.Connect(connectParamsDefault); + await service.Connect(connectParamsQuery); + ConnectionInfo connectionInfo = service.OwnerToConnectionMap[connectParamsDefault.OwnerUri]; + + // There should be 2 connections in the map + Assert.Equal(2, connectionInfo.CountConnections); + + // If I Disconnect only the Query connection, there should be 1 connection in the map + service.Disconnect(disconnectParamsQuery); + Assert.Equal(1, connectionInfo.CountConnections); + + // If I reconnect, there should be 2 again + await service.Connect(connectParamsQuery); + Assert.Equal(2, connectionInfo.CountConnections); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 9347857d..fd512b68 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -82,6 +82,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public static Query GetBasicExecutedQuery() { ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false); + + // Query won't be able to request a new query DbConnection unless the ConnectionService has a + // ConnectionInfo with the same URI as the query, so we will manually set it + ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci; + Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory(new Dictionary())); query.Execute(); query.ExecutionTask.Wait(); @@ -222,6 +227,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, StandardConnectionDetails); } + public static ConnectionInfo CreateConnectedConnectionInfo(Dictionary[][] data, bool throwOnRead, string type = ConnectionType.Default) + { + ConnectionService connectionService = ConnectionService.Instance; + connectionService.OwnerToConnectionMap.Clear(); + connectionService.ConnectionFactory = CreateMockFactory(data, throwOnRead); + + ConnectParams connectParams = new ConnectParams() + { + Connection = StandardConnectionDetails, + OwnerUri = Common.OwnerUri, + Type = type + }; + + connectionService.Connect(connectParams).Wait(); + return connectionService.OwnerToConnectionMap[OwnerUri]; + } + #endregion #region Service Mocking @@ -233,12 +255,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // Create a place for the temp "files" to be written storage = new Dictionary(); - // Create the connection factory with the dataset - var factory = CreateTestConnectionInfo(data, throwOnRead).Factory; - // Mock the connection service var connectionService = new Mock(); - ConnectionInfo ci = new ConnectionInfo(factory, OwnerUri, StandardConnectionDetails); + ConnectionInfo ci = CreateConnectedConnectionInfo(data, throwOnRead); ConnectionInfo outValMock; connectionService .Setup(service => service.TryFindConnection(It.IsAny(), out outValMock)) diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs index c61237fa..fc1c8fde 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -7,10 +7,16 @@ using System.Data.Common; using System.Threading.Tasks; +using System; +using System.Collections.Generic; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.Test.Utility; +using Xunit; namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { @@ -24,12 +30,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I create a query with a udt column in the result set ConnectionInfo connectionInfo = TestObjects.GetTestConnectionInfo(); - Query query = new Query(Common.UdtQuery, connectionInfo, new QueryExecutionSettings(), Common.GetFileStreamFactory()); + Query query = new Query(Common.UdtQuery, connectionInfo, new QueryExecutionSettings(), Common.GetFileStreamFactory(new Dictionary())); // If: // ... I then execute the query DateTime startTime = DateTime.Now; - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... The query should complete within 2 seconds since retry logic should not kick in diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs index d2a67ede..fad77f8f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/QueryTests.cs @@ -163,7 +163,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // If: // ... I create a query from two batches (with separator) - ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false); + ConnectionInfo ci = Common.CreateConnectedConnectionInfo(null, false); + string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery); var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); Query query = new Query(queryText, ci, new QueryExecutionSettings(), fileStreamFactory); @@ -280,6 +281,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution // If: // ... I create a query from an invalid batch ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); + ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci; + var fileStreamFactory = Common.GetFileStreamFactory(new Dictionary()); Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), fileStreamFactory); BatchCallbackHelper(query,