//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Threading.Tasks;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.EditorServices.Utility;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;

namespace Microsoft.SqlTools.ServiceLayer.Connection
{
    /// <summary>
    /// Main class for the Connection Management services
    /// </summary>
    public class ConnectionService
    {
        /// <summary>
        /// Singleton service instance
        /// </summary>
        private static readonly Lazy<ConnectionService> instance
            = new Lazy<ConnectionService>(() => new ConnectionService());

        /// <summary>
        /// Gets the singleton service instance
        /// </summary>
        public static ConnectionService Instance
        {
            get
            {
                return instance.Value;
            }
        }

        /// <summary>
        /// The SQL connection factory object. Created on first use by
        /// <see cref="ConnectionFactory"/> unless injected by the test constructor.
        /// </summary>
        private ISqlConnectionFactory connectionFactory;

        /// <summary>
        /// Maps an owner URI to the connection currently open for that URI.
        /// </summary>
        private readonly Dictionary<string, ConnectionInfo> ownerToConnectionMap = new Dictionary<string, ConnectionInfo>();

        /// <summary>
        /// Service host object for sending/receiving requests/events.
        /// Internal for testing purposes.
        /// </summary>
        internal IProtocolEndpoint ServiceHost { get; set; }

        /// <summary>
        /// Default constructor is private since it's a singleton class
        /// </summary>
        private ConnectionService()
        {
        }

        /// <summary>
        /// Callback for onconnection handler
        /// </summary>
        /// <param name="info">The connection that was just opened.</param>
        public delegate Task OnConnectionHandler(ConnectionInfo info);

        /// <summary>
        /// Callback for ondisconnect handler
        /// </summary>
        /// <param name="summary">Summary of the connection that was just closed.</param>
        public delegate Task OnDisconnectHandler(ConnectionSummary summary);

        /// <summary>
        /// List of onconnection handlers
        /// </summary>
        private readonly List<OnConnectionHandler> onConnectionActivities = new List<OnConnectionHandler>();

        /// <summary>
        /// List of ondisconnect handlers
        /// </summary>
        private readonly List<OnDisconnectHandler> onDisconnectActivities = new List<OnDisconnectHandler>();

        /// <summary>
        /// Gets the SQL connection factory instance, creating a default
        /// <see cref="SqlConnectionFactory"/> on first access if none was injected.
        /// </summary>
        public ISqlConnectionFactory ConnectionFactory
        {
            get
            {
                if (this.connectionFactory == null)
                {
                    this.connectionFactory = new SqlConnectionFactory();
                }
                return this.connectionFactory;
            }
        }

        /// <summary>
        /// Test constructor that injects dependency interfaces
        /// </summary>
        /// <param name="testFactory">Factory used to create SQL connections.</param>
        public ConnectionService(ISqlConnectionFactory testFactory)
        {
            this.connectionFactory = testFactory;
        }

        /// <summary>
        /// Attempts to link a URI to an actively used connection for this URI.
        /// </summary>
        /// <param name="ownerUri">URI that owns the connection. A null URI never matches.</param>
        /// <param name="connectionInfo">The connection for the URI, or null when none exists.</param>
        /// <returns>True if a connection is registered for <paramref name="ownerUri"/>.</returns>
        public bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo)
        {
            // Dictionary.TryGetValue throws on a null key; a Try* method should just report "not found"
            if (ownerUri == null)
            {
                connectionInfo = null;
                return false;
            }

            return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo);
        }

        /// <summary>
        /// Open a connection with the specified connection details.
        /// If the owner URI already has a connection, that connection is closed first.
        /// </summary>
        /// <param name="connectionParams">Owner URI and connection details.</param>
        /// <returns>
        /// A response carrying the new connection ID on success, or an error text in
        /// <c>Messages</c> when validation or opening the connection fails.
        /// </returns>
        public ConnectResponse Connect(ConnectParams connectionParams)
        {
            // Validate parameters
            string paramValidationErrorMessage;
            if (connectionParams == null)
            {
                return new ConnectResponse()
                {
                    Messages = "Error: Connection parameters cannot be null."
                };
            }
            else if (!connectionParams.IsValid(out paramValidationErrorMessage))
            {
                return new ConnectResponse()
                {
                    Messages = paramValidationErrorMessage
                };
            }

            // Disconnect the active connection if the URI is already connected
            ConnectionInfo existingConnection;
            if (ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out existingConnection))
            {
                var disconnectParams = new DisconnectParams()
                {
                    OwnerUri = connectionParams.OwnerUri
                };
                Disconnect(disconnectParams);
            }

            ConnectionInfo connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);

            // Try to connect
            var response = new ConnectResponse();
            try
            {
                // Build the connection string from the input parameters
                string connectionString = ConnectionService.BuildConnectionString(connectionInfo.ConnectionDetails);

                // Create a sql connection instance and open it
                connectionInfo.SqlConnection = connectionInfo.Factory.CreateSqlConnection(connectionString);
                connectionInfo.SqlConnection.Open();
            }
            catch (Exception ex)
            {
                // The connection is never registered, so release it here instead of leaking it
                if (connectionInfo.SqlConnection != null)
                {
                    connectionInfo.SqlConnection.Dispose();
                }

                response.Messages = ex.ToString();
                return response;
            }

            ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;

            // Invoke callback notifications.
            // NOTE(review): the returned Tasks are not awaited, so asynchronous failures
            // inside handlers are not observed here.
            foreach (var activity in this.onConnectionActivities)
            {
                activity(connectionInfo);
            }

            // Return the connection result
            response.ConnectionId = connectionInfo.ConnectionId.ToString();
            return response;
        }

        /// <summary>
        /// Close a connection with the specified connection details.
        /// </summary>
        /// <param name="disconnectParams">Identifies the owner URI to disconnect.</param>
        /// <returns>
        /// True if a connection was found, closed and unregistered. False if the parameters
        /// are missing or no connection exists for the URI.
        /// </returns>
        public bool Disconnect(DisconnectParams disconnectParams)
        {
            // Validate parameters
            if (disconnectParams == null || string.IsNullOrEmpty(disconnectParams.OwnerUri))
            {
                return false;
            }

            // Lookup the connection owned by the URI
            ConnectionInfo info;
            if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info))
            {
                return false;
            }

            // Close the connection
            info.SqlConnection.Close();

            // Remove URI mapping
            ownerToConnectionMap.Remove(disconnectParams.OwnerUri);

            // Invoke callback notifications (not awaited, as in Connect)
            foreach (var activity in this.onDisconnectActivities)
            {
                activity(info.ConnectionDetails);
            }

            // Success
            return true;
        }

        /// <summary>
        /// List all databases on the server specified.
        /// Opens a temporary connection to <c>master</c>, built from a clone of the owner's
        /// existing connection details, and reads the names from <c>sys.databases</c>.
        /// </summary>
        /// <param name="listDatabasesParams">Identifies the owner URI whose server is queried.</param>
        /// <returns>The database names, in the order the server returned them.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="listDatabasesParams"/> is null.</exception>
        /// <exception cref="ArgumentException">The owner URI is null or empty.</exception>
        /// <exception cref="Exception">The owner URI has no existing connection.</exception>
        public ListDatabasesResponse ListDatabases(ListDatabasesParams listDatabasesParams)
        {
            if (listDatabasesParams == null)
            {
                throw new ArgumentNullException("listDatabasesParams");
            }

            // Verify parameters
            var owner = listDatabasesParams.OwnerUri;
            if (string.IsNullOrEmpty(owner))
            {
                throw new ArgumentException("OwnerUri cannot be null or empty");
            }

            // Use the existing connection as a base for the search
            ConnectionInfo info;
            if (!TryFindConnection(owner, out info))
            {
                throw new Exception("Specified OwnerUri \"" + owner + "\" does not have an existing connection");
            }

            // Clone so that pointing at master does not alter the owner's own connection details
            ConnectionDetails connectionDetails = info.ConnectionDetails.Clone();
            connectionDetails.DatabaseName = "master";

            // Connect to master and query sys.databases. The using blocks release the
            // connection, command and reader even when opening or reading throws.
            var results = new List<string>();
            using (var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails)))
            {
                connection.Open();
                using (DbCommand command = connection.CreateCommand())
                {
                    command.CommandText = "SELECT name FROM sys.databases";
                    command.CommandTimeout = 15;
                    command.CommandType = CommandType.Text;

                    using (var reader = command.ExecuteReader())
                    {
                        while (reader.Read())
                        {
                            results.Add(reader[0].ToString());
                        }
                    }
                }

                connection.Close();
            }

            ListDatabasesResponse response = new ListDatabasesResponse();
            response.DatabaseNames = results.ToArray();
            return response;
        }

        /// <summary>
        /// Registers this service's request handlers with the service host and subscribes
        /// to configuration changes.
        /// </summary>
        /// <param name="serviceHost">Endpoint used to receive requests and send events.</param>
        /// <exception cref="ArgumentNullException"><paramref name="serviceHost"/> is null.</exception>
        public void InitializeService(IProtocolEndpoint serviceHost)
        {
            if (serviceHost == null)
            {
                throw new ArgumentNullException("serviceHost");
            }

            this.ServiceHost = serviceHost;

            // Register request and event handlers with the Service Host
            serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest);
            serviceHost.SetRequestHandler(DisconnectRequest.Type, HandleDisconnectRequest);
            serviceHost.SetRequestHandler(ListDatabasesRequest.Type, HandleListDatabasesRequest);

            // Register the configuration update handler
            WorkspaceService<SqlToolsSettings>.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification);
        }

        /// <summary>
        /// Add a new method to be called when the onconnection request is submitted
        /// </summary>
        /// <param name="activity">Handler invoked after a connection is opened.</param>
        public void RegisterOnConnectionTask(OnConnectionHandler activity)
        {
            onConnectionActivities.Add(activity);
        }

        /// <summary>
        /// Add a new method to be called when the ondisconnect request is submitted
        /// </summary>
        /// <param name="activity">Handler invoked after a connection is closed.</param>
        public void RegisterOnDisconnectTask(OnDisconnectHandler activity)
        {
            onDisconnectActivities.Add(activity);
        }

        /// <summary>
        /// Handle new connection requests
        /// </summary>
        /// <param name="connectParams">Parameters of the connect request.</param>
        /// <param name="requestContext">Context used to send the result or an error.</param>
        protected async Task HandleConnectRequest(
            ConnectParams connectParams,
            RequestContext<ConnectResponse> requestContext)
        {
            Logger.Write(LogLevel.Verbose, "HandleConnectRequest");

            try
            {
                // Use this instance (the one whose handlers were registered), not the singleton,
                // so a service created with an injected factory serves its own requests
                ConnectResponse result = this.Connect(connectParams);
                await requestContext.SendResult(result);
            }
            catch (Exception ex)
            {
                await requestContext.SendError(ex.ToString());
            }
        }

        /// <summary>
        /// Handle disconnect requests
        /// </summary>
        /// <param name="disconnectParams">Parameters of the disconnect request.</param>
        /// <param name="requestContext">Context used to send the result or an error.</param>
        protected async Task HandleDisconnectRequest(
            DisconnectParams disconnectParams,
            RequestContext<bool> requestContext)
        {
            Logger.Write(LogLevel.Verbose, "HandleDisconnectRequest");

            try
            {
                bool result = this.Disconnect(disconnectParams);
                await requestContext.SendResult(result);
            }
            catch (Exception ex)
            {
                await requestContext.SendError(ex.ToString());
            }
        }

        /// <summary>
        /// Handle requests to list databases on the current server
        /// </summary>
        /// <param name="listDatabasesParams">Parameters of the list-databases request.</param>
        /// <param name="requestContext">Context used to send the result or an error.</param>
        protected async Task HandleListDatabasesRequest(
            ListDatabasesParams listDatabasesParams,
            RequestContext<ListDatabasesResponse> requestContext)
        {
            Logger.Write(LogLevel.Verbose, "ListDatabasesRequest");

            try
            {
                ListDatabasesResponse result = this.ListDatabases(listDatabasesParams);
                await requestContext.SendResult(result);
            }
            catch (Exception ex)
            {
                await requestContext.SendError(ex.ToString());
            }
        }

        /// <summary>
        /// Configuration-change callback. Currently takes no action and completes immediately.
        /// </summary>
        public Task HandleDidChangeConfigurationNotification(
            SqlToolsSettings newSettings,
            SqlToolsSettings oldSettings,
            EventContext eventContext)
        {
            return Task.FromResult(true);
        }

        /// <summary>
        /// Build a connection string from a connection details instance.
        /// Server name, user id and password are always set. Every other setting is added
        /// only when it has a value.
        /// </summary>
        /// <param name="connectionDetails">The connection settings to translate.</param>
        /// <returns>The connection string produced by <see cref="SqlConnectionStringBuilder"/>.</returns>
        /// <exception cref="ArgumentException">
        /// <c>AuthenticationType</c> or <c>ApplicationIntent</c> has an unsupported value.
        /// </exception>
        public static string BuildConnectionString(ConnectionDetails connectionDetails)
        {
            SqlConnectionStringBuilder connectionBuilder = new SqlConnectionStringBuilder();
            connectionBuilder["Data Source"] = connectionDetails.ServerName;
            connectionBuilder["User Id"] = connectionDetails.UserName;
            connectionBuilder["Password"] = connectionDetails.Password;

            // Check for any optional parameters
            if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
            {
                connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
            }
            if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
            {
                switch (connectionDetails.AuthenticationType)
                {
                    case "Integrated":
                        connectionBuilder.IntegratedSecurity = true;
                        break;
                    case "SqlLogin":
                        break;
                    default:
                        throw new ArgumentException(string.Format("Invalid value \"{0}\" for AuthenticationType. Valid values are \"Integrated\" and \"SqlLogin\".", connectionDetails.AuthenticationType));
                }
            }
            if (connectionDetails.Encrypt.HasValue)
            {
                connectionBuilder.Encrypt = connectionDetails.Encrypt.Value;
            }
            if (connectionDetails.TrustServerCertificate.HasValue)
            {
                connectionBuilder.TrustServerCertificate = connectionDetails.TrustServerCertificate.Value;
            }
            if (connectionDetails.PersistSecurityInfo.HasValue)
            {
                connectionBuilder.PersistSecurityInfo = connectionDetails.PersistSecurityInfo.Value;
            }
            if (connectionDetails.ConnectTimeout.HasValue)
            {
                connectionBuilder.ConnectTimeout = connectionDetails.ConnectTimeout.Value;
            }
            if (connectionDetails.ConnectRetryCount.HasValue)
            {
                connectionBuilder.ConnectRetryCount = connectionDetails.ConnectRetryCount.Value;
            }
            if (connectionDetails.ConnectRetryInterval.HasValue)
            {
                connectionBuilder.ConnectRetryInterval = connectionDetails.ConnectRetryInterval.Value;
            }
            if (!string.IsNullOrEmpty(connectionDetails.ApplicationName))
            {
                connectionBuilder.ApplicationName = connectionDetails.ApplicationName;
            }
            if (!string.IsNullOrEmpty(connectionDetails.WorkstationId))
            {
                connectionBuilder.WorkstationID = connectionDetails.WorkstationId;
            }
            if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
            {
                ApplicationIntent intent;
                switch (connectionDetails.ApplicationIntent)
                {
                    case "ReadOnly":
                        intent = ApplicationIntent.ReadOnly;
                        break;
                    case "ReadWrite":
                        intent = ApplicationIntent.ReadWrite;
                        break;
                    default:
                        throw new ArgumentException(string.Format("Invalid value \"{0}\" for ApplicationIntent. Valid values are \"ReadWrite\" and \"ReadOnly\".", connectionDetails.ApplicationIntent));
                }
                connectionBuilder.ApplicationIntent = intent;
            }
            if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
            {
                connectionBuilder.CurrentLanguage = connectionDetails.CurrentLanguage;
            }
            if (connectionDetails.Pooling.HasValue)
            {
                connectionBuilder.Pooling = connectionDetails.Pooling.Value;
            }
            if (connectionDetails.MaxPoolSize.HasValue)
            {
                connectionBuilder.MaxPoolSize = connectionDetails.MaxPoolSize.Value;
            }
            if (connectionDetails.MinPoolSize.HasValue)
            {
                connectionBuilder.MinPoolSize = connectionDetails.MinPoolSize.Value;
            }
            if (connectionDetails.LoadBalanceTimeout.HasValue)
            {
                connectionBuilder.LoadBalanceTimeout = connectionDetails.LoadBalanceTimeout.Value;
            }
            if (connectionDetails.Replication.HasValue)
            {
                connectionBuilder.Replication = connectionDetails.Replication.Value;
            }
            if (!string.IsNullOrEmpty(connectionDetails.AttachDbFilename))
            {
                connectionBuilder.AttachDBFilename = connectionDetails.AttachDbFilename;
            }
            if (!string.IsNullOrEmpty(connectionDetails.FailoverPartner))
            {
                connectionBuilder.FailoverPartner = connectionDetails.FailoverPartner;
            }
            if (connectionDetails.MultiSubnetFailover.HasValue)
            {
                connectionBuilder.MultiSubnetFailover = connectionDetails.MultiSubnetFailover.Value;
            }
            if (connectionDetails.MultipleActiveResultSets.HasValue)
            {
                connectionBuilder.MultipleActiveResultSets = connectionDetails.MultipleActiveResultSets.Value;
            }
            if (connectionDetails.PacketSize.HasValue)
            {
                connectionBuilder.PacketSize = connectionDetails.PacketSize.Value;
            }
            if (!string.IsNullOrEmpty(connectionDetails.TypeSystemVersion))
            {
                connectionBuilder.TypeSystemVersion = connectionDetails.TypeSystemVersion;
            }

            return connectionBuilder.ToString();
        }

        /// <summary>
        /// Change the database context of a connection.
        /// Does nothing if <paramref name="ownerUri"/> has no connection. Failures are logged, not thrown.
        /// </summary>
        /// <param name="ownerUri">URI of the owner of the connection</param>
        /// <param name="newDatabaseName">Name of the database to change the connection to</param>
        public void ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName)
        {
            ConnectionInfo info;
            if (TryFindConnection(ownerUri, out info))
            {
                try
                {
                    // Only an open connection can switch databases; the stored details are updated either way
                    if (info.SqlConnection.State == ConnectionState.Open)
                    {
                        info.SqlConnection.ChangeDatabase(newDatabaseName);
                    }
                    info.ConnectionDetails.DatabaseName = newDatabaseName;

                    // Fire a connection changed event. ServiceHost is only set by InitializeService,
                    // so skip the notification rather than fail with a null reference.
                    if (ServiceHost != null)
                    {
                        ConnectionChangedParams parameters = new ConnectionChangedParams();
                        ConnectionSummary summary = (ConnectionSummary)(info.ConnectionDetails);
                        parameters.Connection = summary.Clone();
                        parameters.OwnerUri = ownerUri;
                        ServiceHost.SendEvent(ConnectionChangedNotification.Type, parameters);
                    }
                }
                catch (Exception e)
                {
                    Logger.Write(
                        LogLevel.Error,
                        string.Format(
                            "Exception caught while trying to change database context to [{0}] for OwnerUri [{1}]. Exception:{2}",
                            newDatabaseName,
                            ownerUri,
                            e.ToString())
                    );
                }
            }
        }
    }
}