From fb239ac9568f4ffb3e7d5766c18af6a09b579555 Mon Sep 17 00:00:00 2001 From: Matt Irvine Date: Mon, 1 May 2017 21:01:26 -0700 Subject: [PATCH] Support connecting with a connection string (#334) - Add support for connecting with a connection string by passing it as one of the connection parameters - If a connection string is present, it will override any other parameters that are present --- .../Connection/ConnectionService.cs | 2127 +++++++++-------- .../Connection/Contracts/ConnectParams.cs | 3 +- .../Contracts/ConnectParamsExtensions.cs | 5 + .../Connection/Contracts/ConnectionDetails.cs | 74 +- .../Connection/ConnectionServiceTests.cs | 40 +- .../Utility/TestObjects.cs | 34 +- 6 files changed, 1184 insertions(+), 1099 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 42e3c3c3..9c45ec30 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -1,1053 +1,1076 @@ -// -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. -// - -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Data; -using System.Data.Common; -using System.Data.SqlClient; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.SqlTools.Hosting.Protocol; -using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; -using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; -using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; -using Microsoft.SqlTools.ServiceLayer.SqlContext; -using Microsoft.SqlTools.ServiceLayer.Workspace; -using Microsoft.SqlServer.Management.Common; -using Microsoft.SqlTools.Utility; - -namespace Microsoft.SqlTools.ServiceLayer.Connection -{ - /// - /// Main class for the Connection Management services - /// - public class ConnectionService - { - /// - /// Singleton service instance - /// - private static readonly Lazy instance - = new Lazy(() => new ConnectionService()); - - /// - /// Gets the singleton service instance - /// - public static ConnectionService Instance - { - get - { - return instance.Value; - } - } - - /// - /// The SQL connection factory object - /// - private ISqlConnectionFactory connectionFactory; - - private readonly Dictionary ownerToConnectionMap = new Dictionary(); - - /// - /// A map containing all CancellationTokenSource objects that are associated with a given URI/ConnectionType pair. - /// Entries in this map correspond to DbConnection instances that are in the process of connecting. - /// - private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap = - new ConcurrentDictionary(); - - private readonly object cancellationTokenSourceLock = new object(); - - /// - /// Map from script URIs to ConnectionInfo objects - /// This is internal for testing access only - /// - internal Dictionary OwnerToConnectionMap - { - get - { - return this.ownerToConnectionMap; - } - } - - /// - /// Service host object for sending/receiving requests/events. - /// Internal for testing purposes. - /// - internal IProtocolEndpoint ServiceHost - { - get; - set; - } - - /// - /// Default constructor should be private since it's a singleton class, but we need a constructor - /// for use in unit test mocking. - /// - public ConnectionService() - { - } - - /// - /// Callback for onconnection handler - /// - /// - public delegate Task OnConnectionHandler(ConnectionInfo info); - - /// - /// Callback for ondisconnect handler - /// - public delegate Task OnDisconnectHandler(ConnectionSummary summary, string ownerUri); - - /// - /// List of onconnection handlers - /// - private readonly List onConnectionActivities = new List(); - - /// - /// List of ondisconnect handlers - /// - private readonly List onDisconnectActivities = new List(); - - /// - /// Gets the SQL connection factory instance - /// - public ISqlConnectionFactory ConnectionFactory - { - get - { - if (this.connectionFactory == null) - { - this.connectionFactory = new SqlConnectionFactory(); - } - return this.connectionFactory; - } - - internal set { this.connectionFactory = value; } - } - - /// - /// Test constructor that injects dependency interfaces - /// - /// - public ConnectionService(ISqlConnectionFactory testFactory) - { - this.connectionFactory = testFactory; - } - - // Attempts to link a URI to an actively used connection for this URI - public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) - { - return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo); - } - - /// - /// Validates the given ConnectParams object. - /// - /// The params to validate - /// A ConnectionCompleteParams object upon validation error, - /// null upon validation success - public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams) - { - string paramValidationErrorMessage; - if (connectionParams == null) - { - return new ConnectionCompleteParams - { - Messages = SR.ConnectionServiceConnectErrorNullParams - }; - } - if (!connectionParams.IsValid(out paramValidationErrorMessage)) - { - return new ConnectionCompleteParams - { - OwnerUri = connectionParams.OwnerUri, - Messages = paramValidationErrorMessage - }; - } - - // return null upon success - return null; - } - - /// - /// Open a connection with the specified ConnectParams - /// - public virtual async Task Connect(ConnectParams connectionParams) - { - // Validate parameters - ConnectionCompleteParams validationResults = ValidateConnectParams(connectionParams); - if (validationResults != null) - { - return validationResults; - } - - // If there is no ConnectionInfo in the map, create a new ConnectionInfo, - // but wait until later when we are connected to add it to the map. - ConnectionInfo connectionInfo; - bool connectionChanged = false; - if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo)) - { - connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); - } - else if (IsConnectionChanged(connectionParams, connectionInfo)) - { - // We are actively changing the connection information for this connection. We must disconnect - // all active connections, since it represents a full context change - connectionChanged = true; - } - - DisconnectExistingConnectionIfNeeded(connectionParams, connectionInfo, disconnectAll: connectionChanged); - - if (connectionChanged) - { - connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); - } - - // Try to open a connection with the given ConnectParams - ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams); - if (response != null) - { - return response; - } - - // If this is the first connection for this URI, add the ConnectionInfo to the map - bool addToMap = connectionChanged || !ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri); - if (addToMap) - { - ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; - } - - // Return information about the connected SQL Server instance - ConnectionCompleteParams completeParams = GetConnectionCompleteParams(connectionParams.Type, connectionInfo); - // Invoke callback notifications - InvokeOnConnectionActivities(connectionInfo, connectionParams); - - return completeParams; - } - - private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo) - { - if (connectionInfo.HasConnectionType(connectionParams.Type) - && !connectionInfo.ConnectionDetails.IsComparableTo(connectionParams.Connection)) - { - return true; - } - return false; - } - - private bool IsDefaultConnectionType(string connectionType) - { - return string.IsNullOrEmpty(connectionType) || ConnectionType.Default.Equals(connectionType, StringComparison.CurrentCultureIgnoreCase); - } - - private void DisconnectExistingConnectionIfNeeded(ConnectParams connectionParams, ConnectionInfo connectionInfo, bool disconnectAll) - { - // Resolve if it is an existing connection - // Disconnect active connection if the URI is already connected for this connection type - DbConnection existingConnection; - if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection)) - { - var disconnectParams = new DisconnectParams() - { - OwnerUri = connectionParams.OwnerUri, - Type = disconnectAll ? null : connectionParams.Type - }; - Disconnect(disconnectParams); - } - } - - /// - /// Creates a ConnectionCompleteParams as a response to a successful connection. - /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo. - /// - /// A ConnectionCompleteParams in response to the successful connection - private ConnectionCompleteParams GetConnectionCompleteParams(string connectionType, ConnectionInfo connectionInfo) - { - ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionType }; - - try - { - DbConnection connection; - connectionInfo.TryGetConnection(connectionType, out connection); - - // Update with the actual database name in connectionInfo and result - // Doing this here as we know the connection is open - expect to do this only on connecting - connectionInfo.ConnectionDetails.DatabaseName = connection.Database; - response.ConnectionSummary = new ConnectionSummary - { - ServerName = connectionInfo.ConnectionDetails.ServerName, - DatabaseName = connectionInfo.ConnectionDetails.DatabaseName, - UserName = connectionInfo.ConnectionDetails.UserName, - }; - - response.ConnectionId = connectionInfo.ConnectionId.ToString(); - - var reliableConnection = connection as ReliableSqlConnection; - DbConnection underlyingConnection = reliableConnection != null - ? reliableConnection.GetUnderlyingConnection() - : connection; - - ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(underlyingConnection); - response.ServerInfo = new ServerInfo - { - ServerMajorVersion = serverInfo.ServerMajorVersion, - ServerMinorVersion = serverInfo.ServerMinorVersion, - ServerReleaseVersion = serverInfo.ServerReleaseVersion, - EngineEditionId = serverInfo.EngineEditionId, - ServerVersion = serverInfo.ServerVersion, - ServerLevel = serverInfo.ServerLevel, - ServerEdition = serverInfo.ServerEdition, - IsCloud = serverInfo.IsCloud, - AzureVersion = serverInfo.AzureVersion, - OsVersion = serverInfo.OsVersion - }; - connectionInfo.IsAzure = serverInfo.IsCloud; - connectionInfo.MajorVersion = serverInfo.ServerMajorVersion; - connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse); - } - catch (Exception ex) - { - response.Messages = ex.ToString(); - } - - return response; - } - - /// - /// Tries to create and open a connection with the given ConnectParams. - /// - /// null upon success, a ConnectionCompleteParams detailing the error upon failure - private async Task TryOpenConnection(ConnectionInfo connectionInfo, ConnectParams connectionParams) - { - CancellationTokenSource source = null; - DbConnection connection = null; - CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type }; - ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type }; - - try - { - // build the connection string from the input parameters - string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); - - // create a sql connection instance - connection = connectionInfo.Factory.CreateSqlConnection(connectionString); - connectionInfo.AddConnection(connectionParams.Type, connection); - - // Add a cancellation token source so that the connection OpenAsync() can be cancelled - source = new CancellationTokenSource(); - // Locking here to perform two operations as one atomic operation - lock (cancellationTokenSourceLock) - { - // If the URI is currently connecting from a different request, cancel it before we try to connect - CancellationTokenSource currentSource; - if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource)) - { - currentSource.Cancel(); - } - cancelTupleToCancellationTokenSourceMap[cancelKey] = source; - } - - // Open the connection - await connection.OpenAsync(source.Token); - } - catch (SqlException ex) - { - response.ErrorNumber = ex.Number; - response.ErrorMessage = ex.Message; - response.Messages = ex.ToString(); - return response; - } - catch (OperationCanceledException) - { - // OpenAsync was cancelled - response.Messages = SR.ConnectionServiceConnectionCanceled; - return response; - } - catch (Exception ex) - { - response.ErrorMessage = ex.Message; - response.Messages = ex.ToString(); - return response; - } - finally - { - // Remove our cancellation token from the map since we're no longer connecting - // Using a lock here to perform two operations as one atomic operation - lock (cancellationTokenSourceLock) - { - // Only remove the token from the map if it is the same one created by this request - CancellationTokenSource sourceValue; - if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out sourceValue) && sourceValue == source) - { - cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue); - } - source?.Dispose(); - } - } - - // Return null upon success - return null; - } - - /// - /// Gets the existing connection with the given URI and connection type string. If none exists, - /// creates a new connection. This cannot be used to create a default connection or to create a - /// connection if a default connection does not exist. - /// - /// A DB connection for the connection type requested - public async Task GetOrOpenConnection(string ownerUri, string connectionType) - { - Validate.IsNotNullOrEmptyString(nameof(ownerUri), ownerUri); - Validate.IsNotNullOrEmptyString(nameof(connectionType), connectionType); - - // Try to get the ConnectionInfo, if it exists - ConnectionInfo connectionInfo; - if (!ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo)) - { - throw new ArgumentOutOfRangeException(SR.ConnectionServiceListDbErrorNotConnected(ownerUri)); - } - - // Make sure a default connection exists - DbConnection defaultConnection; - if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection)) - { - throw new InvalidOperationException(SR.ConnectionServiceDbErrorDefaultNotConnected(ownerUri)); - } - - // Try to get the DbConnection - DbConnection connection; - if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType) - { - // If the DbConnection does not exist and is not the default connection, create one. - // We can't create the default (initial) connection here because we won't have a ConnectionDetails - // if Connect() has not yet been called. - ConnectParams connectParams = new ConnectParams - { - OwnerUri = ownerUri, - Connection = connectionInfo.ConnectionDetails, - Type = connectionType - }; - await Connect(connectParams); - connectionInfo.TryGetConnection(connectionType, out connection); - } - - return connection; - } - - /// - /// Cancel a connection that is in the process of opening. - /// - public bool CancelConnect(CancelConnectParams cancelParams) - { - // Validate parameters - if (cancelParams == null || string.IsNullOrEmpty(cancelParams.OwnerUri)) - { - return false; - } - - CancelTokenKey cancelKey = new CancelTokenKey - { - OwnerUri = cancelParams.OwnerUri, - Type = cancelParams.Type - }; - - // Cancel any current connection attempts for this URI - CancellationTokenSource source; - if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out source)) - { - try - { - source.Cancel(); - return true; - } - catch - { - return false; - } - } - - return false; - } - - /// - /// Close a connection with the specified connection details. - /// +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Data.SqlClient; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts; +using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Workspace; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlTools.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + /// + /// Main class for the Connection Management services + /// + public class ConnectionService + { + /// + /// Singleton service instance + /// + private static readonly Lazy instance + = new Lazy(() => new ConnectionService()); + + /// + /// Gets the singleton service instance + /// + public static ConnectionService Instance + { + get + { + return instance.Value; + } + } + + /// + /// The SQL connection factory object + /// + private ISqlConnectionFactory connectionFactory; + + private readonly Dictionary ownerToConnectionMap = new Dictionary(); + + /// + /// A map containing all CancellationTokenSource objects that are associated with a given URI/ConnectionType pair. + /// Entries in this map correspond to DbConnection instances that are in the process of connecting. + /// + private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap = + new ConcurrentDictionary(); + + private readonly object cancellationTokenSourceLock = new object(); + + /// + /// Map from script URIs to ConnectionInfo objects + /// This is internal for testing access only + /// + internal Dictionary OwnerToConnectionMap + { + get + { + return this.ownerToConnectionMap; + } + } + + /// + /// Service host object for sending/receiving requests/events. + /// Internal for testing purposes. + /// + internal IProtocolEndpoint ServiceHost + { + get; + set; + } + + /// + /// Default constructor should be private since it's a singleton class, but we need a constructor + /// for use in unit test mocking. + /// + public ConnectionService() + { + } + + /// + /// Callback for onconnection handler + /// + /// + public delegate Task OnConnectionHandler(ConnectionInfo info); + + /// + /// Callback for ondisconnect handler + /// + public delegate Task OnDisconnectHandler(ConnectionSummary summary, string ownerUri); + + /// + /// List of onconnection handlers + /// + private readonly List onConnectionActivities = new List(); + + /// + /// List of ondisconnect handlers + /// + private readonly List onDisconnectActivities = new List(); + + /// + /// Gets the SQL connection factory instance + /// + public ISqlConnectionFactory ConnectionFactory + { + get + { + if (this.connectionFactory == null) + { + this.connectionFactory = new SqlConnectionFactory(); + } + return this.connectionFactory; + } + + internal set { this.connectionFactory = value; } + } + + /// + /// Test constructor that injects dependency interfaces + /// + /// + public ConnectionService(ISqlConnectionFactory testFactory) + { + this.connectionFactory = testFactory; + } + + // Attempts to link a URI to an actively used connection for this URI + public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo) + { + return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo); + } + + /// + /// Validates the given ConnectParams object. + /// + /// The params to validate + /// A ConnectionCompleteParams object upon validation error, + /// null upon validation success + public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams) + { + string paramValidationErrorMessage; + if (connectionParams == null) + { + return new ConnectionCompleteParams + { + Messages = SR.ConnectionServiceConnectErrorNullParams + }; + } + if (!connectionParams.IsValid(out paramValidationErrorMessage)) + { + return new ConnectionCompleteParams + { + OwnerUri = connectionParams.OwnerUri, + Messages = paramValidationErrorMessage + }; + } + + // return null upon success + return null; + } + + /// + /// Open a connection with the specified ConnectParams + /// + public virtual async Task Connect(ConnectParams connectionParams) + { + // Validate parameters + ConnectionCompleteParams validationResults = ValidateConnectParams(connectionParams); + if (validationResults != null) + { + return validationResults; + } + + // If there is no ConnectionInfo in the map, create a new ConnectionInfo, + // but wait until later when we are connected to add it to the map. + ConnectionInfo connectionInfo; + bool connectionChanged = false; + if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo)) + { + connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); + } + else if (IsConnectionChanged(connectionParams, connectionInfo)) + { + // We are actively changing the connection information for this connection. We must disconnect + // all active connections, since it represents a full context change + connectionChanged = true; + } + + DisconnectExistingConnectionIfNeeded(connectionParams, connectionInfo, disconnectAll: connectionChanged); + + if (connectionChanged) + { + connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); + } + + // Try to open a connection with the given ConnectParams + ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams); + if (response != null) + { + return response; + } + + // If this is the first connection for this URI, add the ConnectionInfo to the map + bool addToMap = connectionChanged || !ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri); + if (addToMap) + { + ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; + } + + // Return information about the connected SQL Server instance + ConnectionCompleteParams completeParams = GetConnectionCompleteParams(connectionParams.Type, connectionInfo); + // Invoke callback notifications + InvokeOnConnectionActivities(connectionInfo, connectionParams); + + return completeParams; + } + + private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo) + { + if (connectionInfo.HasConnectionType(connectionParams.Type) + && !connectionInfo.ConnectionDetails.IsComparableTo(connectionParams.Connection)) + { + return true; + } + return false; + } + + private bool IsDefaultConnectionType(string connectionType) + { + return string.IsNullOrEmpty(connectionType) || ConnectionType.Default.Equals(connectionType, StringComparison.CurrentCultureIgnoreCase); + } + + private void DisconnectExistingConnectionIfNeeded(ConnectParams connectionParams, ConnectionInfo connectionInfo, bool disconnectAll) + { + // Resolve if it is an existing connection + // Disconnect active connection if the URI is already connected for this connection type + DbConnection existingConnection; + if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection)) + { + var disconnectParams = new DisconnectParams() + { + OwnerUri = connectionParams.OwnerUri, + Type = disconnectAll ? null : connectionParams.Type + }; + Disconnect(disconnectParams); + } + } + + /// + /// Creates a ConnectionCompleteParams as a response to a successful connection. + /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo. + /// + /// A ConnectionCompleteParams in response to the successful connection + private ConnectionCompleteParams GetConnectionCompleteParams(string connectionType, ConnectionInfo connectionInfo) + { + ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionType }; + + try + { + DbConnection connection; + connectionInfo.TryGetConnection(connectionType, out connection); + + // Update with the actual database name in connectionInfo and result + // Doing this here as we know the connection is open - expect to do this only on connecting + connectionInfo.ConnectionDetails.DatabaseName = connection.Database; + if (!string.IsNullOrEmpty(connectionInfo.ConnectionDetails.ConnectionString)) + { + // If the connection was set up with a connection string, use the connection string to get the details + var connectionString = new SqlConnectionStringBuilder(connection.ConnectionString); + response.ConnectionSummary = new ConnectionSummary + { + ServerName = connectionString.DataSource, + DatabaseName = connectionString.InitialCatalog, + UserName = connectionString.UserID + }; + } + else + { + response.ConnectionSummary = new ConnectionSummary + { + ServerName = connectionInfo.ConnectionDetails.ServerName, + DatabaseName = connectionInfo.ConnectionDetails.DatabaseName, + UserName = connectionInfo.ConnectionDetails.UserName + }; + } + + response.ConnectionId = connectionInfo.ConnectionId.ToString(); + + var reliableConnection = connection as ReliableSqlConnection; + DbConnection underlyingConnection = reliableConnection != null + ? reliableConnection.GetUnderlyingConnection() + : connection; + + ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(underlyingConnection); + response.ServerInfo = new ServerInfo + { + ServerMajorVersion = serverInfo.ServerMajorVersion, + ServerMinorVersion = serverInfo.ServerMinorVersion, + ServerReleaseVersion = serverInfo.ServerReleaseVersion, + EngineEditionId = serverInfo.EngineEditionId, + ServerVersion = serverInfo.ServerVersion, + ServerLevel = serverInfo.ServerLevel, + ServerEdition = serverInfo.ServerEdition, + IsCloud = serverInfo.IsCloud, + AzureVersion = serverInfo.AzureVersion, + OsVersion = serverInfo.OsVersion + }; + connectionInfo.IsAzure = serverInfo.IsCloud; + connectionInfo.MajorVersion = serverInfo.ServerMajorVersion; + connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse); + } + catch (Exception ex) + { + response.Messages = ex.ToString(); + } + + return response; + } + + /// + /// Tries to create and open a connection with the given ConnectParams. + /// + /// null upon success, a ConnectionCompleteParams detailing the error upon failure + private async Task TryOpenConnection(ConnectionInfo connectionInfo, ConnectParams connectionParams) + { + CancellationTokenSource source = null; + DbConnection connection = null; + CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type }; + ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type }; + + try + { + // build the connection string from the input parameters + string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails); + + // create a sql connection instance + connection = connectionInfo.Factory.CreateSqlConnection(connectionString); + connectionInfo.AddConnection(connectionParams.Type, connection); + + // Add a cancellation token source so that the connection OpenAsync() can be cancelled + source = new CancellationTokenSource(); + // Locking here to perform two operations as one atomic operation + lock (cancellationTokenSourceLock) + { + // If the URI is currently connecting from a different request, cancel it before we try to connect + CancellationTokenSource currentSource; + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource)) + { + currentSource.Cancel(); + } + cancelTupleToCancellationTokenSourceMap[cancelKey] = source; + } + + // Open the connection + await connection.OpenAsync(source.Token); + } + catch (SqlException ex) + { + response.ErrorNumber = ex.Number; + response.ErrorMessage = ex.Message; + response.Messages = ex.ToString(); + return response; + } + catch (OperationCanceledException) + { + // OpenAsync was cancelled + response.Messages = SR.ConnectionServiceConnectionCanceled; + return response; + } + catch (Exception ex) + { + response.ErrorMessage = ex.Message; + response.Messages = ex.ToString(); + return response; + } + finally + { + // Remove our cancellation token from the map since we're no longer connecting + // Using a lock here to perform two operations as one atomic operation + lock (cancellationTokenSourceLock) + { + // Only remove the token from the map if it is the same one created by this request + CancellationTokenSource sourceValue; + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out sourceValue) && sourceValue == source) + { + cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue); + } + source?.Dispose(); + } + } + + // Return null upon success + return null; + } + + /// + /// Gets the existing connection with the given URI and connection type string. If none exists, + /// creates a new connection. This cannot be used to create a default connection or to create a + /// connection if a default connection does not exist. + /// + /// A DB connection for the connection type requested + public async Task GetOrOpenConnection(string ownerUri, string connectionType) + { + Validate.IsNotNullOrEmptyString(nameof(ownerUri), ownerUri); + Validate.IsNotNullOrEmptyString(nameof(connectionType), connectionType); + + // Try to get the ConnectionInfo, if it exists + ConnectionInfo connectionInfo; + if (!ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo)) + { + throw new ArgumentOutOfRangeException(SR.ConnectionServiceListDbErrorNotConnected(ownerUri)); + } + + // Make sure a default connection exists + DbConnection defaultConnection; + if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection)) + { + throw new InvalidOperationException(SR.ConnectionServiceDbErrorDefaultNotConnected(ownerUri)); + } + + // Try to get the DbConnection + DbConnection connection; + if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType) + { + // If the DbConnection does not exist and is not the default connection, create one. + // We can't create the default (initial) connection here because we won't have a ConnectionDetails + // if Connect() has not yet been called. + ConnectParams connectParams = new ConnectParams + { + OwnerUri = ownerUri, + Connection = connectionInfo.ConnectionDetails, + Type = connectionType + }; + await Connect(connectParams); + connectionInfo.TryGetConnection(connectionType, out connection); + } + + return connection; + } + + /// + /// Cancel a connection that is in the process of opening. + /// + public bool CancelConnect(CancelConnectParams cancelParams) + { + // Validate parameters + if (cancelParams == null || string.IsNullOrEmpty(cancelParams.OwnerUri)) + { + return false; + } + + CancelTokenKey cancelKey = new CancelTokenKey + { + OwnerUri = cancelParams.OwnerUri, + Type = cancelParams.Type + }; + + // Cancel any current connection attempts for this URI + CancellationTokenSource source; + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out source)) + { + try + { + source.Cancel(); + return true; + } + catch + { + return false; + } + } + + return false; + } + + /// + /// Close a connection with the specified connection details. + /// public virtual bool Disconnect(DisconnectParams disconnectParams) - { - // Validate parameters - if (disconnectParams == null || string.IsNullOrEmpty(disconnectParams.OwnerUri)) - { - return false; - } - - // Cancel if we are in the middle of connecting - if (CancelConnections(disconnectParams.OwnerUri, disconnectParams.Type)) - { - return false; - } - - // Lookup the ConnectionInfo owned by the URI - ConnectionInfo info; - if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info)) - { - return false; - } - - // Call Close() on the connections we want to disconnect - // If no connections were located, return false - if (!CloseConnections(info, disconnectParams.Type)) - { - return false; - } - - // Remove the disconnected connections from the ConnectionInfo map - if (disconnectParams.Type == null) - { - info.RemoveAllConnections(); - } - else - { - info.RemoveConnection(disconnectParams.Type); - } - - // If the ConnectionInfo has no more connections, remove the ConnectionInfo - if (info.CountConnections == 0) - { - ownerToConnectionMap.Remove(disconnectParams.OwnerUri); - } - - // Handle Telemetry disconnect events if we are disconnecting the default connection - if (disconnectParams.Type == null || disconnectParams.Type == ConnectionType.Default) - { - HandleDisconnectTelemetry(info); - InvokeOnDisconnectionActivities(info); - } - - // Return true upon success - return true; - } - - /// - /// Cancel connections associated with the given ownerUri. - /// If connectionType is not null, cancel the connection with the given connectionType - /// If connectionType is null, cancel all pending connections associated with ownerUri. - /// - /// true if a single pending connection associated with the non-null connectionType was - /// found and cancelled, false otherwise - private bool CancelConnections(string ownerUri, string connectionType) - { - // Cancel the connection of the given type - if (connectionType != null) - { - // If we are trying to disconnect a specific connection and it was just cancelled, - // this will return true - return CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = connectionType }); - } - - // Cancel all pending connections - foreach (var entry in cancelTupleToCancellationTokenSourceMap) - { - string entryConnectionUri = entry.Key.OwnerUri; - string entryConnectionType = entry.Key.Type; - if (ownerUri.Equals(entryConnectionUri)) - { - CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = entryConnectionType }); - } - } - - return false; - } - - /// - /// Closes DbConnections associated with the given ConnectionInfo. - /// If connectionType is not null, closes the DbConnection with the type given by connectionType. - /// If connectionType is null, closes all DbConnections. - /// - /// true if connections were found and attempted to be closed, - /// false if no connections were found - private bool CloseConnections(ConnectionInfo connectionInfo, string connectionType) - { - ICollection connectionsToDisconnect = new List(); - if (connectionType == null) - { - connectionsToDisconnect = connectionInfo.AllConnections; - } - else - { - // Make sure there is an existing connection of this type - DbConnection connection; - if (!connectionInfo.TryGetConnection(connectionType, out connection)) - { - return false; - } - connectionsToDisconnect.Add(connection); - } - - if (connectionsToDisconnect.Count == 0) - { - return false; - } - - foreach (DbConnection connection in connectionsToDisconnect) - { - try - { - connection.Close(); - } - catch (Exception) - { - // Ignore - } - } - - return true; - } - - /// - /// List all databases on the server specified - /// - public ListDatabasesResponse ListDatabases(ListDatabasesParams listDatabasesParams) - { - // Verify parameters - var owner = listDatabasesParams.OwnerUri; - if (string.IsNullOrEmpty(owner)) - { - throw new ArgumentException(SR.ConnectionServiceListDbErrorNullOwnerUri); - } - - // Use the existing connection as a base for the search - ConnectionInfo info; - if (!TryFindConnection(owner, out info)) - { - throw new Exception(SR.ConnectionServiceListDbErrorNotConnected(owner)); - } - ConnectionDetails connectionDetails = info.ConnectionDetails.Clone(); - - // Connect to master and query sys.databases - connectionDetails.DatabaseName = "master"; - var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails)); - connection.Open(); - - List results = new List(); - var systemDatabases = new[] {"master", "model", "msdb", "tempdb"}; - using (DbCommand command = connection.CreateCommand()) - { - command.CommandText = "SELECT name FROM sys.databases ORDER BY name ASC"; - command.CommandTimeout = 15; - command.CommandType = CommandType.Text; - - using (var reader = command.ExecuteReader()) - { - while (reader.Read()) - { - results.Add(reader[0].ToString()); - } - } - } - - // Put system databases at the top of the list - results = - results.Where(s => systemDatabases.Any(s.Equals)).Concat( - results.Where(s => systemDatabases.All(x => !s.Equals(x)))).ToList(); - - connection.Close(); - - ListDatabasesResponse response = new ListDatabasesResponse(); - response.DatabaseNames = results.ToArray(); - - return response; - } - - public void InitializeService(IProtocolEndpoint serviceHost) - { - this.ServiceHost = serviceHost; - - // Register request and event handlers with the Service Host - serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest); - serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest); - serviceHost.SetRequestHandler(DisconnectRequest.Type, HandleDisconnectRequest); - serviceHost.SetRequestHandler(ListDatabasesRequest.Type, HandleListDatabasesRequest); - - // Register the configuration update handler - WorkspaceService.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification); - } - - /// - /// Add a new method to be called when the onconnection request is submitted - /// - /// - public void RegisterOnConnectionTask(OnConnectionHandler activity) - { - onConnectionActivities.Add(activity); - } - - /// - /// Add a new method to be called when the ondisconnect request is submitted - /// - public void RegisterOnDisconnectTask(OnDisconnectHandler activity) - { - onDisconnectActivities.Add(activity); - } - - /// - /// Handle new connection requests - /// - /// - /// - /// - protected async Task HandleConnectRequest( - ConnectParams connectParams, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleConnectRequest"); - - try - { - RunConnectRequestHandlerTask(connectParams); - await requestContext.SendResult(true); - } - catch - { - await requestContext.SendResult(false); - } - } - - private void RunConnectRequestHandlerTask(ConnectParams connectParams) - { - // create a task to connect asynchronously so that other requests are not blocked in the meantime - Task.Run(async () => - { - try - { - // result is null if the ConnectParams was successfully validated - ConnectionCompleteParams result = ValidateConnectParams(connectParams); - if (result != null) - { - await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); - return; - } - - // open connection based on request details - result = await Connect(connectParams); - await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); - } - catch (Exception ex) - { - ConnectionCompleteParams result = new ConnectionCompleteParams() - { - Messages = ex.ToString() - }; - await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); - } - }); - } - - /// - /// Handle cancel connect requests - /// - protected async Task HandleCancelConnectRequest( - CancelConnectParams cancelParams, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleCancelConnectRequest"); - - try - { - bool result = CancelConnect(cancelParams); - await requestContext.SendResult(result); - } - catch(Exception ex) - { - await requestContext.SendError(ex.ToString()); - } - } - - /// - /// Handle disconnect requests - /// - protected async Task HandleDisconnectRequest( - DisconnectParams disconnectParams, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "HandleDisconnectRequest"); - - try - { - bool result = Instance.Disconnect(disconnectParams); - await requestContext.SendResult(result); - } - catch(Exception ex) - { - await requestContext.SendError(ex.ToString()); - } - - } - - /// - /// Handle requests to list databases on the current server - /// - protected async Task HandleListDatabasesRequest( - ListDatabasesParams listDatabasesParams, - RequestContext requestContext) - { - Logger.Write(LogLevel.Verbose, "ListDatabasesRequest"); - - try - { - ListDatabasesResponse result = Instance.ListDatabases(listDatabasesParams); - await requestContext.SendResult(result); - } - catch(Exception ex) - { - await requestContext.SendError(ex.ToString()); - } - } - - public Task HandleDidChangeConfigurationNotification( - SqlToolsSettings newSettings, - SqlToolsSettings oldSettings, - EventContext eventContext) - { - return Task.FromResult(true); - } - - /// - /// Build a connection string from a connection details instance - /// - /// - public static string BuildConnectionString(ConnectionDetails connectionDetails) - { - SqlConnectionStringBuilder connectionBuilder = new SqlConnectionStringBuilder - { - ["Data Source"] = connectionDetails.ServerName, - ["User Id"] = connectionDetails.UserName, - ["Password"] = connectionDetails.Password - }; - - // Check for any optional parameters - if (!string.IsNullOrEmpty(connectionDetails.DatabaseName)) - { - connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName; - } - if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType)) - { - switch(connectionDetails.AuthenticationType) - { - case "Integrated": - connectionBuilder.IntegratedSecurity = true; - break; - case "SqlLogin": - break; - default: - throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType)); - } - } - if (connectionDetails.Encrypt.HasValue) - { - connectionBuilder.Encrypt = connectionDetails.Encrypt.Value; - } - if (connectionDetails.TrustServerCertificate.HasValue) - { - connectionBuilder.TrustServerCertificate = connectionDetails.TrustServerCertificate.Value; - } - if (connectionDetails.PersistSecurityInfo.HasValue) - { - connectionBuilder.PersistSecurityInfo = connectionDetails.PersistSecurityInfo.Value; - } - if (connectionDetails.ConnectTimeout.HasValue) - { - connectionBuilder.ConnectTimeout = connectionDetails.ConnectTimeout.Value; - } - if (connectionDetails.ConnectRetryCount.HasValue) - { - connectionBuilder.ConnectRetryCount = connectionDetails.ConnectRetryCount.Value; - } - if (connectionDetails.ConnectRetryInterval.HasValue) - { - connectionBuilder.ConnectRetryInterval = connectionDetails.ConnectRetryInterval.Value; - } - if (!string.IsNullOrEmpty(connectionDetails.ApplicationName)) - { - connectionBuilder.ApplicationName = connectionDetails.ApplicationName; - } - if (!string.IsNullOrEmpty(connectionDetails.WorkstationId)) - { - connectionBuilder.WorkstationID = connectionDetails.WorkstationId; - } - if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent)) - { - ApplicationIntent intent; - switch (connectionDetails.ApplicationIntent) - { - case "ReadOnly": - intent = ApplicationIntent.ReadOnly; - break; - case "ReadWrite": - intent = ApplicationIntent.ReadWrite; - break; - default: - throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent)); - } - connectionBuilder.ApplicationIntent = intent; - } - if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage)) - { - connectionBuilder.CurrentLanguage = connectionDetails.CurrentLanguage; - } - if (connectionDetails.Pooling.HasValue) - { - connectionBuilder.Pooling = connectionDetails.Pooling.Value; - } - if (connectionDetails.MaxPoolSize.HasValue) - { - connectionBuilder.MaxPoolSize = connectionDetails.MaxPoolSize.Value; - } - if (connectionDetails.MinPoolSize.HasValue) - { - connectionBuilder.MinPoolSize = connectionDetails.MinPoolSize.Value; - } - if (connectionDetails.LoadBalanceTimeout.HasValue) - { - connectionBuilder.LoadBalanceTimeout = connectionDetails.LoadBalanceTimeout.Value; - } - if (connectionDetails.Replication.HasValue) - { - connectionBuilder.Replication = connectionDetails.Replication.Value; - } - if (!string.IsNullOrEmpty(connectionDetails.AttachDbFilename)) - { - connectionBuilder.AttachDBFilename = connectionDetails.AttachDbFilename; - } - if (!string.IsNullOrEmpty(connectionDetails.FailoverPartner)) - { - connectionBuilder.FailoverPartner = connectionDetails.FailoverPartner; - } - if (connectionDetails.MultiSubnetFailover.HasValue) - { - connectionBuilder.MultiSubnetFailover = connectionDetails.MultiSubnetFailover.Value; - } - if (connectionDetails.MultipleActiveResultSets.HasValue) - { - connectionBuilder.MultipleActiveResultSets = connectionDetails.MultipleActiveResultSets.Value; - } - if (connectionDetails.PacketSize.HasValue) - { - connectionBuilder.PacketSize = connectionDetails.PacketSize.Value; - } - if (!string.IsNullOrEmpty(connectionDetails.TypeSystemVersion)) - { - connectionBuilder.TypeSystemVersion = connectionDetails.TypeSystemVersion; - } - - return connectionBuilder.ToString(); - } - - /// - /// Change the database context of a connection. - /// - /// URI of the owner of the connection - /// Name of the database to change the connection to - public void ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName) - { - ConnectionInfo info; - if (TryFindConnection(ownerUri, out info)) - { - try - { - foreach (DbConnection connection in info.AllConnections) - { - if (connection.State == ConnectionState.Open) - { - connection.ChangeDatabase(newDatabaseName); - } - } - - info.ConnectionDetails.DatabaseName = newDatabaseName; - - // Fire a connection changed event - ConnectionChangedParams parameters = new ConnectionChangedParams(); - ConnectionSummary summary = info.ConnectionDetails; - parameters.Connection = summary.Clone(); - parameters.OwnerUri = ownerUri; - ServiceHost.SendEvent(ConnectionChangedNotification.Type, parameters); - } - catch (Exception e) - { - Logger.Write( - LogLevel.Error, - string.Format( - "Exception caught while trying to change database context to [{0}] for OwnerUri [{1}]. Exception:{2}", - newDatabaseName, - ownerUri, - e.ToString()) - ); - } - } - } - - /// - /// Invokes the initial on-connect activities if the provided ConnectParams represents the default - /// connection. - /// - private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams) - { - if (connectParams.Type != ConnectionType.Default) - { - return; - } - - foreach (var activity in this.onConnectionActivities) - { - // not awaiting here to allow handlers to run in the background - activity(connectionInfo); - } - } - - /// - /// Invokes the final on-disconnect activities if the provided DisconnectParams represents the default - /// connection or is null - representing that all connections are being disconnected. - /// - private void InvokeOnDisconnectionActivities(ConnectionInfo connectionInfo) - { - foreach (var activity in this.onDisconnectActivities) - { - activity(connectionInfo.ConnectionDetails, connectionInfo.OwnerUri); - } - } - - /// - /// Handles the Telemetry events that occur upon disconnect. - /// - /// - private void HandleDisconnectTelemetry(ConnectionInfo connectionInfo) - { - if (ServiceHost != null) - { - try - { - // Send a telemetry notification for intellisense performance metrics - ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams() - { - Params = new TelemetryProperties - { - Properties = new Dictionary - { - { TelemetryPropertyNames.IsAzure, connectionInfo.IsAzure.ToOneOrZeroString() } - }, - EventName = TelemetryEventNames.IntellisenseQuantile, - Measures = connectionInfo.IntellisenseMetrics.Quantile - } - }); - } - catch (Exception ex) - { - Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString()); - } - } - } - } -} + { + // Validate parameters + if (disconnectParams == null || string.IsNullOrEmpty(disconnectParams.OwnerUri)) + { + return false; + } + + // Cancel if we are in the middle of connecting + if (CancelConnections(disconnectParams.OwnerUri, disconnectParams.Type)) + { + return false; + } + + // Lookup the ConnectionInfo owned by the URI + ConnectionInfo info; + if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info)) + { + return false; + } + + // Call Close() on the connections we want to disconnect + // If no connections were located, return false + if (!CloseConnections(info, disconnectParams.Type)) + { + return false; + } + + // Remove the disconnected connections from the ConnectionInfo map + if (disconnectParams.Type == null) + { + info.RemoveAllConnections(); + } + else + { + info.RemoveConnection(disconnectParams.Type); + } + + // If the ConnectionInfo has no more connections, remove the ConnectionInfo + if (info.CountConnections == 0) + { + ownerToConnectionMap.Remove(disconnectParams.OwnerUri); + } + + // Handle Telemetry disconnect events if we are disconnecting the default connection + if (disconnectParams.Type == null || disconnectParams.Type == ConnectionType.Default) + { + HandleDisconnectTelemetry(info); + InvokeOnDisconnectionActivities(info); + } + + // Return true upon success + return true; + } + + /// + /// Cancel connections associated with the given ownerUri. + /// If connectionType is not null, cancel the connection with the given connectionType + /// If connectionType is null, cancel all pending connections associated with ownerUri. + /// + /// true if a single pending connection associated with the non-null connectionType was + /// found and cancelled, false otherwise + private bool CancelConnections(string ownerUri, string connectionType) + { + // Cancel the connection of the given type + if (connectionType != null) + { + // If we are trying to disconnect a specific connection and it was just cancelled, + // this will return true + return CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = connectionType }); + } + + // Cancel all pending connections + foreach (var entry in cancelTupleToCancellationTokenSourceMap) + { + string entryConnectionUri = entry.Key.OwnerUri; + string entryConnectionType = entry.Key.Type; + if (ownerUri.Equals(entryConnectionUri)) + { + CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = entryConnectionType }); + } + } + + return false; + } + + /// + /// Closes DbConnections associated with the given ConnectionInfo. + /// If connectionType is not null, closes the DbConnection with the type given by connectionType. + /// If connectionType is null, closes all DbConnections. + /// + /// true if connections were found and attempted to be closed, + /// false if no connections were found + private bool CloseConnections(ConnectionInfo connectionInfo, string connectionType) + { + ICollection connectionsToDisconnect = new List(); + if (connectionType == null) + { + connectionsToDisconnect = connectionInfo.AllConnections; + } + else + { + // Make sure there is an existing connection of this type + DbConnection connection; + if (!connectionInfo.TryGetConnection(connectionType, out connection)) + { + return false; + } + connectionsToDisconnect.Add(connection); + } + + if (connectionsToDisconnect.Count == 0) + { + return false; + } + + foreach (DbConnection connection in connectionsToDisconnect) + { + try + { + connection.Close(); + } + catch (Exception) + { + // Ignore + } + } + + return true; + } + + /// + /// List all databases on the server specified + /// + public ListDatabasesResponse ListDatabases(ListDatabasesParams listDatabasesParams) + { + // Verify parameters + var owner = listDatabasesParams.OwnerUri; + if (string.IsNullOrEmpty(owner)) + { + throw new ArgumentException(SR.ConnectionServiceListDbErrorNullOwnerUri); + } + + // Use the existing connection as a base for the search + ConnectionInfo info; + if (!TryFindConnection(owner, out info)) + { + throw new Exception(SR.ConnectionServiceListDbErrorNotConnected(owner)); + } + ConnectionDetails connectionDetails = info.ConnectionDetails.Clone(); + + // Connect to master and query sys.databases + connectionDetails.DatabaseName = "master"; + var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails)); + connection.Open(); + + List results = new List(); + var systemDatabases = new[] {"master", "model", "msdb", "tempdb"}; + using (DbCommand command = connection.CreateCommand()) + { + command.CommandText = "SELECT name FROM sys.databases ORDER BY name ASC"; + command.CommandTimeout = 15; + command.CommandType = CommandType.Text; + + using (var reader = command.ExecuteReader()) + { + while (reader.Read()) + { + results.Add(reader[0].ToString()); + } + } + } + + // Put system databases at the top of the list + results = + results.Where(s => systemDatabases.Any(s.Equals)).Concat( + results.Where(s => systemDatabases.All(x => !s.Equals(x)))).ToList(); + + connection.Close(); + + ListDatabasesResponse response = new ListDatabasesResponse(); + response.DatabaseNames = results.ToArray(); + + return response; + } + + public void InitializeService(IProtocolEndpoint serviceHost) + { + this.ServiceHost = serviceHost; + + // Register request and event handlers with the Service Host + serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest); + serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest); + serviceHost.SetRequestHandler(DisconnectRequest.Type, HandleDisconnectRequest); + serviceHost.SetRequestHandler(ListDatabasesRequest.Type, HandleListDatabasesRequest); + + // Register the configuration update handler + WorkspaceService.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification); + } + + /// + /// Add a new method to be called when the onconnection request is submitted + /// + /// + public void RegisterOnConnectionTask(OnConnectionHandler activity) + { + onConnectionActivities.Add(activity); + } + + /// + /// Add a new method to be called when the ondisconnect request is submitted + /// + public void RegisterOnDisconnectTask(OnDisconnectHandler activity) + { + onDisconnectActivities.Add(activity); + } + + /// + /// Handle new connection requests + /// + /// + /// + /// + protected async Task HandleConnectRequest( + ConnectParams connectParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleConnectRequest"); + + try + { + RunConnectRequestHandlerTask(connectParams); + await requestContext.SendResult(true); + } + catch + { + await requestContext.SendResult(false); + } + } + + private void RunConnectRequestHandlerTask(ConnectParams connectParams) + { + // create a task to connect asynchronously so that other requests are not blocked in the meantime + Task.Run(async () => + { + try + { + // result is null if the ConnectParams was successfully validated + ConnectionCompleteParams result = ValidateConnectParams(connectParams); + if (result != null) + { + await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); + return; + } + + // open connection based on request details + result = await Connect(connectParams); + await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); + } + catch (Exception ex) + { + ConnectionCompleteParams result = new ConnectionCompleteParams() + { + Messages = ex.ToString() + }; + await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); + } + }); + } + + /// + /// Handle cancel connect requests + /// + protected async Task HandleCancelConnectRequest( + CancelConnectParams cancelParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleCancelConnectRequest"); + + try + { + bool result = CancelConnect(cancelParams); + await requestContext.SendResult(result); + } + catch(Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + } + + /// + /// Handle disconnect requests + /// + protected async Task HandleDisconnectRequest( + DisconnectParams disconnectParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleDisconnectRequest"); + + try + { + bool result = Instance.Disconnect(disconnectParams); + await requestContext.SendResult(result); + } + catch(Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + + } + + /// + /// Handle requests to list databases on the current server + /// + protected async Task HandleListDatabasesRequest( + ListDatabasesParams listDatabasesParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "ListDatabasesRequest"); + + try + { + ListDatabasesResponse result = Instance.ListDatabases(listDatabasesParams); + await requestContext.SendResult(result); + } + catch(Exception ex) + { + await requestContext.SendError(ex.ToString()); + } + } + + public Task HandleDidChangeConfigurationNotification( + SqlToolsSettings newSettings, + SqlToolsSettings oldSettings, + EventContext eventContext) + { + return Task.FromResult(true); + } + + /// + /// Build a connection string from a connection details instance + /// + /// + public static string BuildConnectionString(ConnectionDetails connectionDetails) + { + SqlConnectionStringBuilder connectionBuilder; + + // If connectionDetails has a connection string already, just validate and return it + if (!string.IsNullOrEmpty(connectionDetails.ConnectionString)) + { + connectionBuilder = new SqlConnectionStringBuilder(connectionDetails.ConnectionString); + return connectionBuilder.ToString(); + } + + connectionBuilder = new SqlConnectionStringBuilder + { + ["Data Source"] = connectionDetails.ServerName, + ["User Id"] = connectionDetails.UserName, + ["Password"] = connectionDetails.Password + }; + + // Check for any optional parameters + if (!string.IsNullOrEmpty(connectionDetails.DatabaseName)) + { + connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName; + } + if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType)) + { + switch(connectionDetails.AuthenticationType) + { + case "Integrated": + connectionBuilder.IntegratedSecurity = true; + break; + case "SqlLogin": + break; + default: + throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType)); + } + } + if (connectionDetails.Encrypt.HasValue) + { + connectionBuilder.Encrypt = connectionDetails.Encrypt.Value; + } + if (connectionDetails.TrustServerCertificate.HasValue) + { + connectionBuilder.TrustServerCertificate = connectionDetails.TrustServerCertificate.Value; + } + if (connectionDetails.PersistSecurityInfo.HasValue) + { + connectionBuilder.PersistSecurityInfo = connectionDetails.PersistSecurityInfo.Value; + } + if (connectionDetails.ConnectTimeout.HasValue) + { + connectionBuilder.ConnectTimeout = connectionDetails.ConnectTimeout.Value; + } + if (connectionDetails.ConnectRetryCount.HasValue) + { + connectionBuilder.ConnectRetryCount = connectionDetails.ConnectRetryCount.Value; + } + if (connectionDetails.ConnectRetryInterval.HasValue) + { + connectionBuilder.ConnectRetryInterval = connectionDetails.ConnectRetryInterval.Value; + } + if (!string.IsNullOrEmpty(connectionDetails.ApplicationName)) + { + connectionBuilder.ApplicationName = connectionDetails.ApplicationName; + } + if (!string.IsNullOrEmpty(connectionDetails.WorkstationId)) + { + connectionBuilder.WorkstationID = connectionDetails.WorkstationId; + } + if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent)) + { + ApplicationIntent intent; + switch (connectionDetails.ApplicationIntent) + { + case "ReadOnly": + intent = ApplicationIntent.ReadOnly; + break; + case "ReadWrite": + intent = ApplicationIntent.ReadWrite; + break; + default: + throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent)); + } + connectionBuilder.ApplicationIntent = intent; + } + if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage)) + { + connectionBuilder.CurrentLanguage = connectionDetails.CurrentLanguage; + } + if (connectionDetails.Pooling.HasValue) + { + connectionBuilder.Pooling = connectionDetails.Pooling.Value; + } + if (connectionDetails.MaxPoolSize.HasValue) + { + connectionBuilder.MaxPoolSize = connectionDetails.MaxPoolSize.Value; + } + if (connectionDetails.MinPoolSize.HasValue) + { + connectionBuilder.MinPoolSize = connectionDetails.MinPoolSize.Value; + } + if (connectionDetails.LoadBalanceTimeout.HasValue) + { + connectionBuilder.LoadBalanceTimeout = connectionDetails.LoadBalanceTimeout.Value; + } + if (connectionDetails.Replication.HasValue) + { + connectionBuilder.Replication = connectionDetails.Replication.Value; + } + if (!string.IsNullOrEmpty(connectionDetails.AttachDbFilename)) + { + connectionBuilder.AttachDBFilename = connectionDetails.AttachDbFilename; + } + if (!string.IsNullOrEmpty(connectionDetails.FailoverPartner)) + { + connectionBuilder.FailoverPartner = connectionDetails.FailoverPartner; + } + if (connectionDetails.MultiSubnetFailover.HasValue) + { + connectionBuilder.MultiSubnetFailover = connectionDetails.MultiSubnetFailover.Value; + } + if (connectionDetails.MultipleActiveResultSets.HasValue) + { + connectionBuilder.MultipleActiveResultSets = connectionDetails.MultipleActiveResultSets.Value; + } + if (connectionDetails.PacketSize.HasValue) + { + connectionBuilder.PacketSize = connectionDetails.PacketSize.Value; + } + if (!string.IsNullOrEmpty(connectionDetails.TypeSystemVersion)) + { + connectionBuilder.TypeSystemVersion = connectionDetails.TypeSystemVersion; + } + + return connectionBuilder.ToString(); + } + + /// + /// Change the database context of a connection. + /// + /// URI of the owner of the connection + /// Name of the database to change the connection to + public void ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName) + { + ConnectionInfo info; + if (TryFindConnection(ownerUri, out info)) + { + try + { + foreach (DbConnection connection in info.AllConnections) + { + if (connection.State == ConnectionState.Open) + { + connection.ChangeDatabase(newDatabaseName); + } + } + + info.ConnectionDetails.DatabaseName = newDatabaseName; + + // Fire a connection changed event + ConnectionChangedParams parameters = new ConnectionChangedParams(); + ConnectionSummary summary = info.ConnectionDetails; + parameters.Connection = summary.Clone(); + parameters.OwnerUri = ownerUri; + ServiceHost.SendEvent(ConnectionChangedNotification.Type, parameters); + } + catch (Exception e) + { + Logger.Write( + LogLevel.Error, + string.Format( + "Exception caught while trying to change database context to [{0}] for OwnerUri [{1}]. Exception:{2}", + newDatabaseName, + ownerUri, + e.ToString()) + ); + } + } + } + + /// + /// Invokes the initial on-connect activities if the provided ConnectParams represents the default + /// connection. + /// + private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams) + { + if (connectParams.Type != ConnectionType.Default) + { + return; + } + + foreach (var activity in this.onConnectionActivities) + { + // not awaiting here to allow handlers to run in the background + activity(connectionInfo); + } + } + + /// + /// Invokes the final on-disconnect activities if the provided DisconnectParams represents the default + /// connection or is null - representing that all connections are being disconnected. + /// + private void InvokeOnDisconnectionActivities(ConnectionInfo connectionInfo) + { + foreach (var activity in this.onDisconnectActivities) + { + activity(connectionInfo.ConnectionDetails, connectionInfo.OwnerUri); + } + } + + /// + /// Handles the Telemetry events that occur upon disconnect. + /// + /// + private void HandleDisconnectTelemetry(ConnectionInfo connectionInfo) + { + if (ServiceHost != null) + { + try + { + // Send a telemetry notification for intellisense performance metrics + ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams() + { + Params = new TelemetryProperties + { + Properties = new Dictionary + { + { TelemetryPropertyNames.IsAzure, connectionInfo.IsAzure.ToOneOrZeroString() } + }, + EventName = TelemetryEventNames.IntellisenseQuantile, + Measures = connectionInfo.IntellisenseMetrics.Quantile + } + }); + } + catch (Exception ex) + { + Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString()); + } + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs index 58d38ee4..3333ce00 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs @@ -14,7 +14,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace /// or a virtual file representing an object in a database. /// - public string OwnerUri { get; set; } + public string OwnerUri { get; set; } + /// /// Contains the required parameters to initialize a connection to a database. /// A connection will identified by its server name, database name and user name. diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs index 7b581658..6b191dde 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs @@ -24,6 +24,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { errorMessage = SR.ConnectionParamsValidateNullConnection; } + else if (!string.IsNullOrEmpty(parameters.Connection.ConnectionString)) + { + // Do not check other connection parameters if a connection string is present + return string.IsNullOrEmpty(errorMessage); + } else if (string.IsNullOrEmpty(parameters.Connection.ServerName)) { errorMessage = SR.ConnectionParamsValidateNullServerName; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs index 68dfa51f..57312cbf 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -5,7 +5,7 @@ using System; using System.Collections.Generic; -using System.Globalization; +using System.Globalization; using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts @@ -443,6 +443,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts } } + /// + /// Gets or sets a string value to be used as the connection string. If given, all other options will be ignored. + /// + public string ConnectionString + { + get + { + return GetOptionValue("connectionString"); + } + + set + { + SetOptionValue("connectionString", value); + } + } + private T GetOptionValue(string name) { T result = default(T); @@ -485,33 +501,33 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { Options.Add(name, value); } - } - - public bool IsComparableTo(ConnectionDetails other) - { - if (other == null) - { - return false; - } - - if (!string.Equals(ServerName, other.ServerName) - || !string.Equals(AuthenticationType, other.AuthenticationType) - || !string.Equals(UserName, other.UserName)) - { - return false; - } - - // For database name, only compare if neither is empty. This is important - // Since it allows for handling of connections to the default database, but is - // not a 100% accurate heuristic. - if (!string.IsNullOrEmpty(DatabaseName) - && !string.IsNullOrEmpty(other.DatabaseName) - && !string.Equals(DatabaseName, other.DatabaseName)) - { - return false; - } - - return true; - } + } + + public bool IsComparableTo(ConnectionDetails other) + { + if (other == null) + { + return false; + } + + if (ServerName != other.ServerName + || AuthenticationType != other.AuthenticationType + || UserName != other.UserName) + { + return false; + } + + // For database name, only compare if neither is empty. This is important + // Since it allows for handling of connections to the default database, but is + // not a 100% accurate heuristic. + if (!string.IsNullOrEmpty(DatabaseName) + && !string.IsNullOrEmpty(other.DatabaseName) + && DatabaseName != other.DatabaseName) + { + return false; + } + + return true; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 690f7577..9e279b60 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -325,10 +325,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection .Returns((string connString) => { dummySqlConnection.ConnectionString = connString; - SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); - - // Database name is respected. Follow heuristic where empty DB name really means Master - var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; + SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); + + // Database name is respected. Follow heuristic where empty DB name really means Master + var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; dummySqlConnection.SetDatabase(dbName); return dummySqlConnection; }); @@ -1210,5 +1210,37 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection await Assert.ThrowsAsync( () => service.GetOrOpenConnection(TestObjects.ScriptUri, ConnectionType.Query)); } + + [Fact] + public async Task ConnectionWithConnectionStringSucceeds() + { + var connectionParameters = TestObjects.GetTestConnectionParams(true); + var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); + + Assert.NotEmpty(connectionResult.ConnectionId); + } + + [Fact] + public async Task ConnectionWithBadConnectionStringFails() + { + var connectionParameters = TestObjects.GetTestConnectionParams(true); + connectionParameters.Connection.ConnectionString = "thisisnotavalidconnectionstring"; + var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); + + Assert.NotEmpty(connectionResult.ErrorMessage); + } + + [Fact] + public async Task ConnectionWithConnectionStringOverridesParameters() + { + var connectionParameters = TestObjects.GetTestConnectionParams(); + connectionParameters.Connection.ServerName = "overriddenServerName"; + var connectionString = TestObjects.GetTestConnectionParams(true).Connection.ConnectionString; + connectionParameters.Connection.ConnectionString = connectionString; + + // Connect and verify that the server name has been overridden + var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); + Assert.NotEqual(connectionParameters.Connection.ServerName, connectionResult.ConnectionSummary.ServerName); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs index 240877e2..c9cf078e 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs @@ -44,12 +44,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility GetTestConnectionDetails()); } - public static ConnectParams GetTestConnectionParams() + public static ConnectParams GetTestConnectionParams(bool useConnectionString = false) { return new ConnectParams() { OwnerUri = ScriptUri, - Connection = GetTestConnectionDetails() + Connection = GetTestConnectionDetails(useConnectionString) }; } @@ -66,8 +66,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility ServerEdition = "Developer Edition", ServerLevel = "" }; - } - + } + /// /// Creates a test sql connection factory instance /// @@ -80,8 +80,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility /// /// Creates a test connection details object /// - public static ConnectionDetails GetTestConnectionDetails() + public static ConnectionDetails GetTestConnectionDetails(bool useConnectionString = false) { + if (useConnectionString) + { + return new ConnectionDetails() + { + ConnectionString = "User ID=user;PWD=password;Database=databaseName;Server=serverName" + }; + } + return new ConnectionDetails() { UserName = "user", @@ -214,9 +222,9 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility } public override string ConnectionString { get; set; } - public override string Database + public override string Database { - get { return _database; } + get { return _database; } } public override ConnectionState State { get; } @@ -233,13 +241,13 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility // No Op } - /// - /// Test helper method to set the database value - /// + /// + /// Test helper method to set the database value + /// /// - public void SetDatabase(string database) - { - this._database = database; + public void SetDatabase(string database) + { + this._database = database; } }