diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index 42e3c3c3..9c45ec30 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -1,1053 +1,1076 @@
-//
-// Copyright (c) Microsoft. All rights reserved.
-// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-//
-
-using System;
-using System.Collections.Concurrent;
-using System.Collections.Generic;
-using System.Data;
-using System.Data.Common;
-using System.Data.SqlClient;
-using System.Linq;
-using System.Threading;
-using System.Threading.Tasks;
-using Microsoft.SqlTools.Hosting.Protocol;
-using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
-using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
-using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
-using Microsoft.SqlTools.ServiceLayer.SqlContext;
-using Microsoft.SqlTools.ServiceLayer.Workspace;
-using Microsoft.SqlServer.Management.Common;
-using Microsoft.SqlTools.Utility;
-
-namespace Microsoft.SqlTools.ServiceLayer.Connection
-{
- ///
- /// Main class for the Connection Management services
- ///
- public class ConnectionService
- {
- ///
- /// Singleton service instance
- ///
- private static readonly Lazy instance
- = new Lazy(() => new ConnectionService());
-
- ///
- /// Gets the singleton service instance
- ///
- public static ConnectionService Instance
- {
- get
- {
- return instance.Value;
- }
- }
-
- ///
- /// The SQL connection factory object
- ///
- private ISqlConnectionFactory connectionFactory;
-
- private readonly Dictionary ownerToConnectionMap = new Dictionary();
-
- ///
- /// A map containing all CancellationTokenSource objects that are associated with a given URI/ConnectionType pair.
- /// Entries in this map correspond to DbConnection instances that are in the process of connecting.
- ///
- private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap =
- new ConcurrentDictionary();
-
- private readonly object cancellationTokenSourceLock = new object();
-
- ///
- /// Map from script URIs to ConnectionInfo objects
- /// This is internal for testing access only
- ///
- internal Dictionary OwnerToConnectionMap
- {
- get
- {
- return this.ownerToConnectionMap;
- }
- }
-
- ///
- /// Service host object for sending/receiving requests/events.
- /// Internal for testing purposes.
- ///
- internal IProtocolEndpoint ServiceHost
- {
- get;
- set;
- }
-
- ///
- /// Default constructor should be private since it's a singleton class, but we need a constructor
- /// for use in unit test mocking.
- ///
- public ConnectionService()
- {
- }
-
- ///
- /// Callback for onconnection handler
- ///
- ///
- public delegate Task OnConnectionHandler(ConnectionInfo info);
-
- ///
- /// Callback for ondisconnect handler
- ///
- public delegate Task OnDisconnectHandler(ConnectionSummary summary, string ownerUri);
-
- ///
- /// List of onconnection handlers
- ///
- private readonly List onConnectionActivities = new List();
-
- ///
- /// List of ondisconnect handlers
- ///
- private readonly List onDisconnectActivities = new List();
-
- ///
- /// Gets the SQL connection factory instance
- ///
- public ISqlConnectionFactory ConnectionFactory
- {
- get
- {
- if (this.connectionFactory == null)
- {
- this.connectionFactory = new SqlConnectionFactory();
- }
- return this.connectionFactory;
- }
-
- internal set { this.connectionFactory = value; }
- }
-
- ///
- /// Test constructor that injects dependency interfaces
- ///
- ///
- public ConnectionService(ISqlConnectionFactory testFactory)
- {
- this.connectionFactory = testFactory;
- }
-
- // Attempts to link a URI to an actively used connection for this URI
- public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo)
- {
- return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo);
- }
-
- ///
- /// Validates the given ConnectParams object.
- ///
- /// The params to validate
- /// A ConnectionCompleteParams object upon validation error,
- /// null upon validation success
- public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams)
- {
- string paramValidationErrorMessage;
- if (connectionParams == null)
- {
- return new ConnectionCompleteParams
- {
- Messages = SR.ConnectionServiceConnectErrorNullParams
- };
- }
- if (!connectionParams.IsValid(out paramValidationErrorMessage))
- {
- return new ConnectionCompleteParams
- {
- OwnerUri = connectionParams.OwnerUri,
- Messages = paramValidationErrorMessage
- };
- }
-
- // return null upon success
- return null;
- }
-
- ///
- /// Open a connection with the specified ConnectParams
- ///
- public virtual async Task Connect(ConnectParams connectionParams)
- {
- // Validate parameters
- ConnectionCompleteParams validationResults = ValidateConnectParams(connectionParams);
- if (validationResults != null)
- {
- return validationResults;
- }
-
- // If there is no ConnectionInfo in the map, create a new ConnectionInfo,
- // but wait until later when we are connected to add it to the map.
- ConnectionInfo connectionInfo;
- bool connectionChanged = false;
- if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo))
- {
- connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
- }
- else if (IsConnectionChanged(connectionParams, connectionInfo))
- {
- // We are actively changing the connection information for this connection. We must disconnect
- // all active connections, since it represents a full context change
- connectionChanged = true;
- }
-
- DisconnectExistingConnectionIfNeeded(connectionParams, connectionInfo, disconnectAll: connectionChanged);
-
- if (connectionChanged)
- {
- connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
- }
-
- // Try to open a connection with the given ConnectParams
- ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams);
- if (response != null)
- {
- return response;
- }
-
- // If this is the first connection for this URI, add the ConnectionInfo to the map
- bool addToMap = connectionChanged || !ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri);
- if (addToMap)
- {
- ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
- }
-
- // Return information about the connected SQL Server instance
- ConnectionCompleteParams completeParams = GetConnectionCompleteParams(connectionParams.Type, connectionInfo);
- // Invoke callback notifications
- InvokeOnConnectionActivities(connectionInfo, connectionParams);
-
- return completeParams;
- }
-
- private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo)
- {
- if (connectionInfo.HasConnectionType(connectionParams.Type)
- && !connectionInfo.ConnectionDetails.IsComparableTo(connectionParams.Connection))
- {
- return true;
- }
- return false;
- }
-
- private bool IsDefaultConnectionType(string connectionType)
- {
- return string.IsNullOrEmpty(connectionType) || ConnectionType.Default.Equals(connectionType, StringComparison.CurrentCultureIgnoreCase);
- }
-
- private void DisconnectExistingConnectionIfNeeded(ConnectParams connectionParams, ConnectionInfo connectionInfo, bool disconnectAll)
- {
- // Resolve if it is an existing connection
- // Disconnect active connection if the URI is already connected for this connection type
- DbConnection existingConnection;
- if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection))
- {
- var disconnectParams = new DisconnectParams()
- {
- OwnerUri = connectionParams.OwnerUri,
- Type = disconnectAll ? null : connectionParams.Type
- };
- Disconnect(disconnectParams);
- }
- }
-
- ///
- /// Creates a ConnectionCompleteParams as a response to a successful connection.
- /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo.
- ///
- /// A ConnectionCompleteParams in response to the successful connection
- private ConnectionCompleteParams GetConnectionCompleteParams(string connectionType, ConnectionInfo connectionInfo)
- {
- ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionType };
-
- try
- {
- DbConnection connection;
- connectionInfo.TryGetConnection(connectionType, out connection);
-
- // Update with the actual database name in connectionInfo and result
- // Doing this here as we know the connection is open - expect to do this only on connecting
- connectionInfo.ConnectionDetails.DatabaseName = connection.Database;
- response.ConnectionSummary = new ConnectionSummary
- {
- ServerName = connectionInfo.ConnectionDetails.ServerName,
- DatabaseName = connectionInfo.ConnectionDetails.DatabaseName,
- UserName = connectionInfo.ConnectionDetails.UserName,
- };
-
- response.ConnectionId = connectionInfo.ConnectionId.ToString();
-
- var reliableConnection = connection as ReliableSqlConnection;
- DbConnection underlyingConnection = reliableConnection != null
- ? reliableConnection.GetUnderlyingConnection()
- : connection;
-
- ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(underlyingConnection);
- response.ServerInfo = new ServerInfo
- {
- ServerMajorVersion = serverInfo.ServerMajorVersion,
- ServerMinorVersion = serverInfo.ServerMinorVersion,
- ServerReleaseVersion = serverInfo.ServerReleaseVersion,
- EngineEditionId = serverInfo.EngineEditionId,
- ServerVersion = serverInfo.ServerVersion,
- ServerLevel = serverInfo.ServerLevel,
- ServerEdition = serverInfo.ServerEdition,
- IsCloud = serverInfo.IsCloud,
- AzureVersion = serverInfo.AzureVersion,
- OsVersion = serverInfo.OsVersion
- };
- connectionInfo.IsAzure = serverInfo.IsCloud;
- connectionInfo.MajorVersion = serverInfo.ServerMajorVersion;
- connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
- }
- catch (Exception ex)
- {
- response.Messages = ex.ToString();
- }
-
- return response;
- }
-
- ///
- /// Tries to create and open a connection with the given ConnectParams.
- ///
- /// null upon success, a ConnectionCompleteParams detailing the error upon failure
- private async Task TryOpenConnection(ConnectionInfo connectionInfo, ConnectParams connectionParams)
- {
- CancellationTokenSource source = null;
- DbConnection connection = null;
- CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type };
- ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type };
-
- try
- {
- // build the connection string from the input parameters
- string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
-
- // create a sql connection instance
- connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
- connectionInfo.AddConnection(connectionParams.Type, connection);
-
- // Add a cancellation token source so that the connection OpenAsync() can be cancelled
- source = new CancellationTokenSource();
- // Locking here to perform two operations as one atomic operation
- lock (cancellationTokenSourceLock)
- {
- // If the URI is currently connecting from a different request, cancel it before we try to connect
- CancellationTokenSource currentSource;
- if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource))
- {
- currentSource.Cancel();
- }
- cancelTupleToCancellationTokenSourceMap[cancelKey] = source;
- }
-
- // Open the connection
- await connection.OpenAsync(source.Token);
- }
- catch (SqlException ex)
- {
- response.ErrorNumber = ex.Number;
- response.ErrorMessage = ex.Message;
- response.Messages = ex.ToString();
- return response;
- }
- catch (OperationCanceledException)
- {
- // OpenAsync was cancelled
- response.Messages = SR.ConnectionServiceConnectionCanceled;
- return response;
- }
- catch (Exception ex)
- {
- response.ErrorMessage = ex.Message;
- response.Messages = ex.ToString();
- return response;
- }
- finally
- {
- // Remove our cancellation token from the map since we're no longer connecting
- // Using a lock here to perform two operations as one atomic operation
- lock (cancellationTokenSourceLock)
- {
- // Only remove the token from the map if it is the same one created by this request
- CancellationTokenSource sourceValue;
- if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out sourceValue) && sourceValue == source)
- {
- cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue);
- }
- source?.Dispose();
- }
- }
-
- // Return null upon success
- return null;
- }
-
- ///
- /// Gets the existing connection with the given URI and connection type string. If none exists,
- /// creates a new connection. This cannot be used to create a default connection or to create a
- /// connection if a default connection does not exist.
- ///
- /// A DB connection for the connection type requested
- public async Task GetOrOpenConnection(string ownerUri, string connectionType)
- {
- Validate.IsNotNullOrEmptyString(nameof(ownerUri), ownerUri);
- Validate.IsNotNullOrEmptyString(nameof(connectionType), connectionType);
-
- // Try to get the ConnectionInfo, if it exists
- ConnectionInfo connectionInfo;
- if (!ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo))
- {
- throw new ArgumentOutOfRangeException(SR.ConnectionServiceListDbErrorNotConnected(ownerUri));
- }
-
- // Make sure a default connection exists
- DbConnection defaultConnection;
- if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection))
- {
- throw new InvalidOperationException(SR.ConnectionServiceDbErrorDefaultNotConnected(ownerUri));
- }
-
- // Try to get the DbConnection
- DbConnection connection;
- if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType)
- {
- // If the DbConnection does not exist and is not the default connection, create one.
- // We can't create the default (initial) connection here because we won't have a ConnectionDetails
- // if Connect() has not yet been called.
- ConnectParams connectParams = new ConnectParams
- {
- OwnerUri = ownerUri,
- Connection = connectionInfo.ConnectionDetails,
- Type = connectionType
- };
- await Connect(connectParams);
- connectionInfo.TryGetConnection(connectionType, out connection);
- }
-
- return connection;
- }
-
- ///
- /// Cancel a connection that is in the process of opening.
- ///
- public bool CancelConnect(CancelConnectParams cancelParams)
- {
- // Validate parameters
- if (cancelParams == null || string.IsNullOrEmpty(cancelParams.OwnerUri))
- {
- return false;
- }
-
- CancelTokenKey cancelKey = new CancelTokenKey
- {
- OwnerUri = cancelParams.OwnerUri,
- Type = cancelParams.Type
- };
-
- // Cancel any current connection attempts for this URI
- CancellationTokenSource source;
- if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out source))
- {
- try
- {
- source.Cancel();
- return true;
- }
- catch
- {
- return false;
- }
- }
-
- return false;
- }
-
- ///
- /// Close a connection with the specified connection details.
- ///
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Data;
+using System.Data.Common;
+using System.Data.SqlClient;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
+using Microsoft.SqlTools.ServiceLayer.LanguageServices.Contracts;
+using Microsoft.SqlTools.ServiceLayer.SqlContext;
+using Microsoft.SqlTools.ServiceLayer.Workspace;
+using Microsoft.SqlServer.Management.Common;
+using Microsoft.SqlTools.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Connection
+{
+ ///
+ /// Main class for the Connection Management services
+ ///
+ public class ConnectionService
+ {
+ ///
+ /// Singleton service instance
+ ///
+ private static readonly Lazy instance
+ = new Lazy(() => new ConnectionService());
+
+ ///
+ /// Gets the singleton service instance
+ ///
+ public static ConnectionService Instance
+ {
+ get
+ {
+ return instance.Value;
+ }
+ }
+
+ ///
+ /// The SQL connection factory object
+ ///
+ private ISqlConnectionFactory connectionFactory;
+
+ private readonly Dictionary ownerToConnectionMap = new Dictionary();
+
+ ///
+ /// A map containing all CancellationTokenSource objects that are associated with a given URI/ConnectionType pair.
+ /// Entries in this map correspond to DbConnection instances that are in the process of connecting.
+ ///
+ private readonly ConcurrentDictionary cancelTupleToCancellationTokenSourceMap =
+ new ConcurrentDictionary();
+
+ private readonly object cancellationTokenSourceLock = new object();
+
+ ///
+ /// Map from script URIs to ConnectionInfo objects
+ /// This is internal for testing access only
+ ///
+ internal Dictionary OwnerToConnectionMap
+ {
+ get
+ {
+ return this.ownerToConnectionMap;
+ }
+ }
+
+ ///
+ /// Service host object for sending/receiving requests/events.
+ /// Internal for testing purposes.
+ ///
+ internal IProtocolEndpoint ServiceHost
+ {
+ get;
+ set;
+ }
+
+ ///
+ /// Default constructor should be private since it's a singleton class, but we need a constructor
+ /// for use in unit test mocking.
+ ///
+ public ConnectionService()
+ {
+ }
+
+ ///
+ /// Callback for onconnection handler
+ ///
+ ///
+ public delegate Task OnConnectionHandler(ConnectionInfo info);
+
+ ///
+ /// Callback for ondisconnect handler
+ ///
+ public delegate Task OnDisconnectHandler(ConnectionSummary summary, string ownerUri);
+
+ ///
+ /// List of onconnection handlers
+ ///
+ private readonly List onConnectionActivities = new List();
+
+ ///
+ /// List of ondisconnect handlers
+ ///
+ private readonly List onDisconnectActivities = new List();
+
+ ///
+ /// Gets the SQL connection factory instance
+ ///
+ public ISqlConnectionFactory ConnectionFactory
+ {
+ get
+ {
+ if (this.connectionFactory == null)
+ {
+ this.connectionFactory = new SqlConnectionFactory();
+ }
+ return this.connectionFactory;
+ }
+
+ internal set { this.connectionFactory = value; }
+ }
+
+ ///
+ /// Test constructor that injects dependency interfaces
+ ///
+ ///
+ public ConnectionService(ISqlConnectionFactory testFactory)
+ {
+ this.connectionFactory = testFactory;
+ }
+
+ // Attempts to link a URI to an actively used connection for this URI
+ public virtual bool TryFindConnection(string ownerUri, out ConnectionInfo connectionInfo)
+ {
+ return this.ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo);
+ }
+
+ ///
+ /// Validates the given ConnectParams object.
+ ///
+ /// The params to validate
+ /// A ConnectionCompleteParams object upon validation error,
+ /// null upon validation success
+ public ConnectionCompleteParams ValidateConnectParams(ConnectParams connectionParams)
+ {
+ string paramValidationErrorMessage;
+ if (connectionParams == null)
+ {
+ return new ConnectionCompleteParams
+ {
+ Messages = SR.ConnectionServiceConnectErrorNullParams
+ };
+ }
+ if (!connectionParams.IsValid(out paramValidationErrorMessage))
+ {
+ return new ConnectionCompleteParams
+ {
+ OwnerUri = connectionParams.OwnerUri,
+ Messages = paramValidationErrorMessage
+ };
+ }
+
+ // return null upon success
+ return null;
+ }
+
+ ///
+ /// Open a connection with the specified ConnectParams
+ ///
+ public virtual async Task Connect(ConnectParams connectionParams)
+ {
+ // Validate parameters
+ ConnectionCompleteParams validationResults = ValidateConnectParams(connectionParams);
+ if (validationResults != null)
+ {
+ return validationResults;
+ }
+
+ // If there is no ConnectionInfo in the map, create a new ConnectionInfo,
+ // but wait until later when we are connected to add it to the map.
+ ConnectionInfo connectionInfo;
+ bool connectionChanged = false;
+ if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo))
+ {
+ connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
+ }
+ else if (IsConnectionChanged(connectionParams, connectionInfo))
+ {
+ // We are actively changing the connection information for this connection. We must disconnect
+ // all active connections, since it represents a full context change
+ connectionChanged = true;
+ }
+
+ DisconnectExistingConnectionIfNeeded(connectionParams, connectionInfo, disconnectAll: connectionChanged);
+
+ if (connectionChanged)
+ {
+ connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
+ }
+
+ // Try to open a connection with the given ConnectParams
+ ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams);
+ if (response != null)
+ {
+ return response;
+ }
+
+ // If this is the first connection for this URI, add the ConnectionInfo to the map
+ bool addToMap = connectionChanged || !ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri);
+ if (addToMap)
+ {
+ ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
+ }
+
+ // Return information about the connected SQL Server instance
+ ConnectionCompleteParams completeParams = GetConnectionCompleteParams(connectionParams.Type, connectionInfo);
+ // Invoke callback notifications
+ InvokeOnConnectionActivities(connectionInfo, connectionParams);
+
+ return completeParams;
+ }
+
+ private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo)
+ {
+ if (connectionInfo.HasConnectionType(connectionParams.Type)
+ && !connectionInfo.ConnectionDetails.IsComparableTo(connectionParams.Connection))
+ {
+ return true;
+ }
+ return false;
+ }
+
+ private bool IsDefaultConnectionType(string connectionType)
+ {
+ return string.IsNullOrEmpty(connectionType) || ConnectionType.Default.Equals(connectionType, StringComparison.CurrentCultureIgnoreCase);
+ }
+
+ private void DisconnectExistingConnectionIfNeeded(ConnectParams connectionParams, ConnectionInfo connectionInfo, bool disconnectAll)
+ {
+ // Resolve if it is an existing connection
+ // Disconnect active connection if the URI is already connected for this connection type
+ DbConnection existingConnection;
+ if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection))
+ {
+ var disconnectParams = new DisconnectParams()
+ {
+ OwnerUri = connectionParams.OwnerUri,
+ Type = disconnectAll ? null : connectionParams.Type
+ };
+ Disconnect(disconnectParams);
+ }
+ }
+
+ ///
+ /// Creates a ConnectionCompleteParams as a response to a successful connection.
+ /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo.
+ ///
+ /// A ConnectionCompleteParams in response to the successful connection
+ private ConnectionCompleteParams GetConnectionCompleteParams(string connectionType, ConnectionInfo connectionInfo)
+ {
+ ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionType };
+
+ try
+ {
+ DbConnection connection;
+ connectionInfo.TryGetConnection(connectionType, out connection);
+
+ // Update with the actual database name in connectionInfo and result
+ // Doing this here as we know the connection is open - expect to do this only on connecting
+ connectionInfo.ConnectionDetails.DatabaseName = connection.Database;
+ if (!string.IsNullOrEmpty(connectionInfo.ConnectionDetails.ConnectionString))
+ {
+ // If the connection was set up with a connection string, use the connection string to get the details
+ var connectionString = new SqlConnectionStringBuilder(connection.ConnectionString);
+ response.ConnectionSummary = new ConnectionSummary
+ {
+ ServerName = connectionString.DataSource,
+ DatabaseName = connectionString.InitialCatalog,
+ UserName = connectionString.UserID
+ };
+ }
+ else
+ {
+ response.ConnectionSummary = new ConnectionSummary
+ {
+ ServerName = connectionInfo.ConnectionDetails.ServerName,
+ DatabaseName = connectionInfo.ConnectionDetails.DatabaseName,
+ UserName = connectionInfo.ConnectionDetails.UserName
+ };
+ }
+
+ response.ConnectionId = connectionInfo.ConnectionId.ToString();
+
+ var reliableConnection = connection as ReliableSqlConnection;
+ DbConnection underlyingConnection = reliableConnection != null
+ ? reliableConnection.GetUnderlyingConnection()
+ : connection;
+
+ ReliableConnectionHelper.ServerInfo serverInfo = ReliableConnectionHelper.GetServerVersion(underlyingConnection);
+ response.ServerInfo = new ServerInfo
+ {
+ ServerMajorVersion = serverInfo.ServerMajorVersion,
+ ServerMinorVersion = serverInfo.ServerMinorVersion,
+ ServerReleaseVersion = serverInfo.ServerReleaseVersion,
+ EngineEditionId = serverInfo.EngineEditionId,
+ ServerVersion = serverInfo.ServerVersion,
+ ServerLevel = serverInfo.ServerLevel,
+ ServerEdition = serverInfo.ServerEdition,
+ IsCloud = serverInfo.IsCloud,
+ AzureVersion = serverInfo.AzureVersion,
+ OsVersion = serverInfo.OsVersion
+ };
+ connectionInfo.IsAzure = serverInfo.IsCloud;
+ connectionInfo.MajorVersion = serverInfo.ServerMajorVersion;
+ connectionInfo.IsSqlDW = (serverInfo.EngineEditionId == (int)DatabaseEngineEdition.SqlDataWarehouse);
+ }
+ catch (Exception ex)
+ {
+ response.Messages = ex.ToString();
+ }
+
+ return response;
+ }
+
+ ///
+ /// Tries to create and open a connection with the given ConnectParams.
+ ///
+ /// null upon success, a ConnectionCompleteParams detailing the error upon failure
+ private async Task TryOpenConnection(ConnectionInfo connectionInfo, ConnectParams connectionParams)
+ {
+ CancellationTokenSource source = null;
+ DbConnection connection = null;
+ CancelTokenKey cancelKey = new CancelTokenKey { OwnerUri = connectionParams.OwnerUri, Type = connectionParams.Type };
+ ConnectionCompleteParams response = new ConnectionCompleteParams { OwnerUri = connectionInfo.OwnerUri, Type = connectionParams.Type };
+
+ try
+ {
+ // build the connection string from the input parameters
+ string connectionString = BuildConnectionString(connectionInfo.ConnectionDetails);
+
+ // create a sql connection instance
+ connection = connectionInfo.Factory.CreateSqlConnection(connectionString);
+ connectionInfo.AddConnection(connectionParams.Type, connection);
+
+ // Add a cancellation token source so that the connection OpenAsync() can be cancelled
+ source = new CancellationTokenSource();
+ // Locking here to perform two operations as one atomic operation
+ lock (cancellationTokenSourceLock)
+ {
+ // If the URI is currently connecting from a different request, cancel it before we try to connect
+ CancellationTokenSource currentSource;
+ if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource))
+ {
+ currentSource.Cancel();
+ }
+ cancelTupleToCancellationTokenSourceMap[cancelKey] = source;
+ }
+
+ // Open the connection
+ await connection.OpenAsync(source.Token);
+ }
+ catch (SqlException ex)
+ {
+ response.ErrorNumber = ex.Number;
+ response.ErrorMessage = ex.Message;
+ response.Messages = ex.ToString();
+ return response;
+ }
+ catch (OperationCanceledException)
+ {
+ // OpenAsync was cancelled
+ response.Messages = SR.ConnectionServiceConnectionCanceled;
+ return response;
+ }
+ catch (Exception ex)
+ {
+ response.ErrorMessage = ex.Message;
+ response.Messages = ex.ToString();
+ return response;
+ }
+ finally
+ {
+ // Remove our cancellation token from the map since we're no longer connecting
+ // Using a lock here to perform two operations as one atomic operation
+ lock (cancellationTokenSourceLock)
+ {
+ // Only remove the token from the map if it is the same one created by this request
+ CancellationTokenSource sourceValue;
+ if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out sourceValue) && sourceValue == source)
+ {
+ cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue);
+ }
+ source?.Dispose();
+ }
+ }
+
+ // Return null upon success
+ return null;
+ }
+
+ ///
+ /// Gets the existing connection with the given URI and connection type string. If none exists,
+ /// creates a new connection. This cannot be used to create a default connection or to create a
+ /// connection if a default connection does not exist.
+ ///
+ /// A DB connection for the connection type requested
+ public async Task GetOrOpenConnection(string ownerUri, string connectionType)
+ {
+ Validate.IsNotNullOrEmptyString(nameof(ownerUri), ownerUri);
+ Validate.IsNotNullOrEmptyString(nameof(connectionType), connectionType);
+
+ // Try to get the ConnectionInfo, if it exists
+ ConnectionInfo connectionInfo;
+ if (!ownerToConnectionMap.TryGetValue(ownerUri, out connectionInfo))
+ {
+ throw new ArgumentOutOfRangeException(SR.ConnectionServiceListDbErrorNotConnected(ownerUri));
+ }
+
+ // Make sure a default connection exists
+ DbConnection defaultConnection;
+ if (!connectionInfo.TryGetConnection(ConnectionType.Default, out defaultConnection))
+ {
+ throw new InvalidOperationException(SR.ConnectionServiceDbErrorDefaultNotConnected(ownerUri));
+ }
+
+ // Try to get the DbConnection
+ DbConnection connection;
+ if (!connectionInfo.TryGetConnection(connectionType, out connection) && ConnectionType.Default != connectionType)
+ {
+ // If the DbConnection does not exist and is not the default connection, create one.
+ // We can't create the default (initial) connection here because we won't have a ConnectionDetails
+ // if Connect() has not yet been called.
+ ConnectParams connectParams = new ConnectParams
+ {
+ OwnerUri = ownerUri,
+ Connection = connectionInfo.ConnectionDetails,
+ Type = connectionType
+ };
+ await Connect(connectParams);
+ connectionInfo.TryGetConnection(connectionType, out connection);
+ }
+
+ return connection;
+ }
+
+ ///
+ /// Cancel a connection that is in the process of opening.
+ ///
+ public bool CancelConnect(CancelConnectParams cancelParams)
+ {
+ // Validate parameters
+ if (cancelParams == null || string.IsNullOrEmpty(cancelParams.OwnerUri))
+ {
+ return false;
+ }
+
+ CancelTokenKey cancelKey = new CancelTokenKey
+ {
+ OwnerUri = cancelParams.OwnerUri,
+ Type = cancelParams.Type
+ };
+
+ // Cancel any current connection attempts for this URI
+ CancellationTokenSource source;
+ if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out source))
+ {
+ try
+ {
+ source.Cancel();
+ return true;
+ }
+ catch
+ {
+ return false;
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Close a connection with the specified connection details.
+ ///
public virtual bool Disconnect(DisconnectParams disconnectParams)
- {
- // Validate parameters
- if (disconnectParams == null || string.IsNullOrEmpty(disconnectParams.OwnerUri))
- {
- return false;
- }
-
- // Cancel if we are in the middle of connecting
- if (CancelConnections(disconnectParams.OwnerUri, disconnectParams.Type))
- {
- return false;
- }
-
- // Lookup the ConnectionInfo owned by the URI
- ConnectionInfo info;
- if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info))
- {
- return false;
- }
-
- // Call Close() on the connections we want to disconnect
- // If no connections were located, return false
- if (!CloseConnections(info, disconnectParams.Type))
- {
- return false;
- }
-
- // Remove the disconnected connections from the ConnectionInfo map
- if (disconnectParams.Type == null)
- {
- info.RemoveAllConnections();
- }
- else
- {
- info.RemoveConnection(disconnectParams.Type);
- }
-
- // If the ConnectionInfo has no more connections, remove the ConnectionInfo
- if (info.CountConnections == 0)
- {
- ownerToConnectionMap.Remove(disconnectParams.OwnerUri);
- }
-
- // Handle Telemetry disconnect events if we are disconnecting the default connection
- if (disconnectParams.Type == null || disconnectParams.Type == ConnectionType.Default)
- {
- HandleDisconnectTelemetry(info);
- InvokeOnDisconnectionActivities(info);
- }
-
- // Return true upon success
- return true;
- }
-
- ///
- /// Cancel connections associated with the given ownerUri.
- /// If connectionType is not null, cancel the connection with the given connectionType
- /// If connectionType is null, cancel all pending connections associated with ownerUri.
- ///
- /// true if a single pending connection associated with the non-null connectionType was
- /// found and cancelled, false otherwise
- private bool CancelConnections(string ownerUri, string connectionType)
- {
- // Cancel the connection of the given type
- if (connectionType != null)
- {
- // If we are trying to disconnect a specific connection and it was just cancelled,
- // this will return true
- return CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = connectionType });
- }
-
- // Cancel all pending connections
- foreach (var entry in cancelTupleToCancellationTokenSourceMap)
- {
- string entryConnectionUri = entry.Key.OwnerUri;
- string entryConnectionType = entry.Key.Type;
- if (ownerUri.Equals(entryConnectionUri))
- {
- CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = entryConnectionType });
- }
- }
-
- return false;
- }
-
- ///
- /// Closes DbConnections associated with the given ConnectionInfo.
- /// If connectionType is not null, closes the DbConnection with the type given by connectionType.
- /// If connectionType is null, closes all DbConnections.
- ///
- /// true if connections were found and attempted to be closed,
- /// false if no connections were found
- private bool CloseConnections(ConnectionInfo connectionInfo, string connectionType)
- {
- ICollection connectionsToDisconnect = new List();
- if (connectionType == null)
- {
- connectionsToDisconnect = connectionInfo.AllConnections;
- }
- else
- {
- // Make sure there is an existing connection of this type
- DbConnection connection;
- if (!connectionInfo.TryGetConnection(connectionType, out connection))
- {
- return false;
- }
- connectionsToDisconnect.Add(connection);
- }
-
- if (connectionsToDisconnect.Count == 0)
- {
- return false;
- }
-
- foreach (DbConnection connection in connectionsToDisconnect)
- {
- try
- {
- connection.Close();
- }
- catch (Exception)
- {
- // Ignore
- }
- }
-
- return true;
- }
-
- ///
- /// List all databases on the server specified
- ///
- public ListDatabasesResponse ListDatabases(ListDatabasesParams listDatabasesParams)
- {
- // Verify parameters
- var owner = listDatabasesParams.OwnerUri;
- if (string.IsNullOrEmpty(owner))
- {
- throw new ArgumentException(SR.ConnectionServiceListDbErrorNullOwnerUri);
- }
-
- // Use the existing connection as a base for the search
- ConnectionInfo info;
- if (!TryFindConnection(owner, out info))
- {
- throw new Exception(SR.ConnectionServiceListDbErrorNotConnected(owner));
- }
- ConnectionDetails connectionDetails = info.ConnectionDetails.Clone();
-
- // Connect to master and query sys.databases
- connectionDetails.DatabaseName = "master";
- var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails));
- connection.Open();
-
- List results = new List();
- var systemDatabases = new[] {"master", "model", "msdb", "tempdb"};
- using (DbCommand command = connection.CreateCommand())
- {
- command.CommandText = "SELECT name FROM sys.databases ORDER BY name ASC";
- command.CommandTimeout = 15;
- command.CommandType = CommandType.Text;
-
- using (var reader = command.ExecuteReader())
- {
- while (reader.Read())
- {
- results.Add(reader[0].ToString());
- }
- }
- }
-
- // Put system databases at the top of the list
- results =
- results.Where(s => systemDatabases.Any(s.Equals)).Concat(
- results.Where(s => systemDatabases.All(x => !s.Equals(x)))).ToList();
-
- connection.Close();
-
- ListDatabasesResponse response = new ListDatabasesResponse();
- response.DatabaseNames = results.ToArray();
-
- return response;
- }
-
- public void InitializeService(IProtocolEndpoint serviceHost)
- {
- this.ServiceHost = serviceHost;
-
- // Register request and event handlers with the Service Host
- serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest);
- serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest);
- serviceHost.SetRequestHandler(DisconnectRequest.Type, HandleDisconnectRequest);
- serviceHost.SetRequestHandler(ListDatabasesRequest.Type, HandleListDatabasesRequest);
-
- // Register the configuration update handler
- WorkspaceService.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification);
- }
-
- ///
- /// Add a new method to be called when the onconnection request is submitted
- ///
- ///
- public void RegisterOnConnectionTask(OnConnectionHandler activity)
- {
- onConnectionActivities.Add(activity);
- }
-
- ///
- /// Add a new method to be called when the ondisconnect request is submitted
- ///
- public void RegisterOnDisconnectTask(OnDisconnectHandler activity)
- {
- onDisconnectActivities.Add(activity);
- }
-
- ///
- /// Handle new connection requests
- ///
- ///
- ///
- ///
- protected async Task HandleConnectRequest(
- ConnectParams connectParams,
- RequestContext requestContext)
- {
- Logger.Write(LogLevel.Verbose, "HandleConnectRequest");
-
- try
- {
- RunConnectRequestHandlerTask(connectParams);
- await requestContext.SendResult(true);
- }
- catch
- {
- await requestContext.SendResult(false);
- }
- }
-
- private void RunConnectRequestHandlerTask(ConnectParams connectParams)
- {
- // create a task to connect asynchronously so that other requests are not blocked in the meantime
- Task.Run(async () =>
- {
- try
- {
- // result is null if the ConnectParams was successfully validated
- ConnectionCompleteParams result = ValidateConnectParams(connectParams);
- if (result != null)
- {
- await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
- return;
- }
-
- // open connection based on request details
- result = await Connect(connectParams);
- await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
- }
- catch (Exception ex)
- {
- ConnectionCompleteParams result = new ConnectionCompleteParams()
- {
- Messages = ex.ToString()
- };
- await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
- }
- });
- }
-
- ///
- /// Handle cancel connect requests
- ///
- protected async Task HandleCancelConnectRequest(
- CancelConnectParams cancelParams,
- RequestContext requestContext)
- {
- Logger.Write(LogLevel.Verbose, "HandleCancelConnectRequest");
-
- try
- {
- bool result = CancelConnect(cancelParams);
- await requestContext.SendResult(result);
- }
- catch(Exception ex)
- {
- await requestContext.SendError(ex.ToString());
- }
- }
-
- ///
- /// Handle disconnect requests
- ///
- protected async Task HandleDisconnectRequest(
- DisconnectParams disconnectParams,
- RequestContext requestContext)
- {
- Logger.Write(LogLevel.Verbose, "HandleDisconnectRequest");
-
- try
- {
- bool result = Instance.Disconnect(disconnectParams);
- await requestContext.SendResult(result);
- }
- catch(Exception ex)
- {
- await requestContext.SendError(ex.ToString());
- }
-
- }
-
- ///
- /// Handle requests to list databases on the current server
- ///
- protected async Task HandleListDatabasesRequest(
- ListDatabasesParams listDatabasesParams,
- RequestContext requestContext)
- {
- Logger.Write(LogLevel.Verbose, "ListDatabasesRequest");
-
- try
- {
- ListDatabasesResponse result = Instance.ListDatabases(listDatabasesParams);
- await requestContext.SendResult(result);
- }
- catch(Exception ex)
- {
- await requestContext.SendError(ex.ToString());
- }
- }
-
- public Task HandleDidChangeConfigurationNotification(
- SqlToolsSettings newSettings,
- SqlToolsSettings oldSettings,
- EventContext eventContext)
- {
- return Task.FromResult(true);
- }
-
- ///
- /// Build a connection string from a connection details instance
- ///
- ///
- public static string BuildConnectionString(ConnectionDetails connectionDetails)
- {
- SqlConnectionStringBuilder connectionBuilder = new SqlConnectionStringBuilder
- {
- ["Data Source"] = connectionDetails.ServerName,
- ["User Id"] = connectionDetails.UserName,
- ["Password"] = connectionDetails.Password
- };
-
- // Check for any optional parameters
- if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
- {
- connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
- }
- if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
- {
- switch(connectionDetails.AuthenticationType)
- {
- case "Integrated":
- connectionBuilder.IntegratedSecurity = true;
- break;
- case "SqlLogin":
- break;
- default:
- throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
- }
- }
- if (connectionDetails.Encrypt.HasValue)
- {
- connectionBuilder.Encrypt = connectionDetails.Encrypt.Value;
- }
- if (connectionDetails.TrustServerCertificate.HasValue)
- {
- connectionBuilder.TrustServerCertificate = connectionDetails.TrustServerCertificate.Value;
- }
- if (connectionDetails.PersistSecurityInfo.HasValue)
- {
- connectionBuilder.PersistSecurityInfo = connectionDetails.PersistSecurityInfo.Value;
- }
- if (connectionDetails.ConnectTimeout.HasValue)
- {
- connectionBuilder.ConnectTimeout = connectionDetails.ConnectTimeout.Value;
- }
- if (connectionDetails.ConnectRetryCount.HasValue)
- {
- connectionBuilder.ConnectRetryCount = connectionDetails.ConnectRetryCount.Value;
- }
- if (connectionDetails.ConnectRetryInterval.HasValue)
- {
- connectionBuilder.ConnectRetryInterval = connectionDetails.ConnectRetryInterval.Value;
- }
- if (!string.IsNullOrEmpty(connectionDetails.ApplicationName))
- {
- connectionBuilder.ApplicationName = connectionDetails.ApplicationName;
- }
- if (!string.IsNullOrEmpty(connectionDetails.WorkstationId))
- {
- connectionBuilder.WorkstationID = connectionDetails.WorkstationId;
- }
- if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
- {
- ApplicationIntent intent;
- switch (connectionDetails.ApplicationIntent)
- {
- case "ReadOnly":
- intent = ApplicationIntent.ReadOnly;
- break;
- case "ReadWrite":
- intent = ApplicationIntent.ReadWrite;
- break;
- default:
- throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
- }
- connectionBuilder.ApplicationIntent = intent;
- }
- if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
- {
- connectionBuilder.CurrentLanguage = connectionDetails.CurrentLanguage;
- }
- if (connectionDetails.Pooling.HasValue)
- {
- connectionBuilder.Pooling = connectionDetails.Pooling.Value;
- }
- if (connectionDetails.MaxPoolSize.HasValue)
- {
- connectionBuilder.MaxPoolSize = connectionDetails.MaxPoolSize.Value;
- }
- if (connectionDetails.MinPoolSize.HasValue)
- {
- connectionBuilder.MinPoolSize = connectionDetails.MinPoolSize.Value;
- }
- if (connectionDetails.LoadBalanceTimeout.HasValue)
- {
- connectionBuilder.LoadBalanceTimeout = connectionDetails.LoadBalanceTimeout.Value;
- }
- if (connectionDetails.Replication.HasValue)
- {
- connectionBuilder.Replication = connectionDetails.Replication.Value;
- }
- if (!string.IsNullOrEmpty(connectionDetails.AttachDbFilename))
- {
- connectionBuilder.AttachDBFilename = connectionDetails.AttachDbFilename;
- }
- if (!string.IsNullOrEmpty(connectionDetails.FailoverPartner))
- {
- connectionBuilder.FailoverPartner = connectionDetails.FailoverPartner;
- }
- if (connectionDetails.MultiSubnetFailover.HasValue)
- {
- connectionBuilder.MultiSubnetFailover = connectionDetails.MultiSubnetFailover.Value;
- }
- if (connectionDetails.MultipleActiveResultSets.HasValue)
- {
- connectionBuilder.MultipleActiveResultSets = connectionDetails.MultipleActiveResultSets.Value;
- }
- if (connectionDetails.PacketSize.HasValue)
- {
- connectionBuilder.PacketSize = connectionDetails.PacketSize.Value;
- }
- if (!string.IsNullOrEmpty(connectionDetails.TypeSystemVersion))
- {
- connectionBuilder.TypeSystemVersion = connectionDetails.TypeSystemVersion;
- }
-
- return connectionBuilder.ToString();
- }
-
- ///
- /// Change the database context of a connection.
- ///
- /// URI of the owner of the connection
- /// Name of the database to change the connection to
- public void ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName)
- {
- ConnectionInfo info;
- if (TryFindConnection(ownerUri, out info))
- {
- try
- {
- foreach (DbConnection connection in info.AllConnections)
- {
- if (connection.State == ConnectionState.Open)
- {
- connection.ChangeDatabase(newDatabaseName);
- }
- }
-
- info.ConnectionDetails.DatabaseName = newDatabaseName;
-
- // Fire a connection changed event
- ConnectionChangedParams parameters = new ConnectionChangedParams();
- ConnectionSummary summary = info.ConnectionDetails;
- parameters.Connection = summary.Clone();
- parameters.OwnerUri = ownerUri;
- ServiceHost.SendEvent(ConnectionChangedNotification.Type, parameters);
- }
- catch (Exception e)
- {
- Logger.Write(
- LogLevel.Error,
- string.Format(
- "Exception caught while trying to change database context to [{0}] for OwnerUri [{1}]. Exception:{2}",
- newDatabaseName,
- ownerUri,
- e.ToString())
- );
- }
- }
- }
-
- ///
- /// Invokes the initial on-connect activities if the provided ConnectParams represents the default
- /// connection.
- ///
- private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams)
- {
- if (connectParams.Type != ConnectionType.Default)
- {
- return;
- }
-
- foreach (var activity in this.onConnectionActivities)
- {
- // not awaiting here to allow handlers to run in the background
- activity(connectionInfo);
- }
- }
-
- ///
- /// Invokes the final on-disconnect activities if the provided DisconnectParams represents the default
- /// connection or is null - representing that all connections are being disconnected.
- ///
- private void InvokeOnDisconnectionActivities(ConnectionInfo connectionInfo)
- {
- foreach (var activity in this.onDisconnectActivities)
- {
- activity(connectionInfo.ConnectionDetails, connectionInfo.OwnerUri);
- }
- }
-
- ///
- /// Handles the Telemetry events that occur upon disconnect.
- ///
- ///
- private void HandleDisconnectTelemetry(ConnectionInfo connectionInfo)
- {
- if (ServiceHost != null)
- {
- try
- {
- // Send a telemetry notification for intellisense performance metrics
- ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams()
- {
- Params = new TelemetryProperties
- {
- Properties = new Dictionary
- {
- { TelemetryPropertyNames.IsAzure, connectionInfo.IsAzure.ToOneOrZeroString() }
- },
- EventName = TelemetryEventNames.IntellisenseQuantile,
- Measures = connectionInfo.IntellisenseMetrics.Quantile
- }
- });
- }
- catch (Exception ex)
- {
- Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString());
- }
- }
- }
- }
-}
+ {
+ // Validate parameters
+ if (disconnectParams == null || string.IsNullOrEmpty(disconnectParams.OwnerUri))
+ {
+ return false;
+ }
+
+ // Cancel if we are in the middle of connecting
+ if (CancelConnections(disconnectParams.OwnerUri, disconnectParams.Type))
+ {
+ return false;
+ }
+
+ // Lookup the ConnectionInfo owned by the URI
+ ConnectionInfo info;
+ if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info))
+ {
+ return false;
+ }
+
+ // Call Close() on the connections we want to disconnect
+ // If no connections were located, return false
+ if (!CloseConnections(info, disconnectParams.Type))
+ {
+ return false;
+ }
+
+ // Remove the disconnected connections from the ConnectionInfo map
+ if (disconnectParams.Type == null)
+ {
+ info.RemoveAllConnections();
+ }
+ else
+ {
+ info.RemoveConnection(disconnectParams.Type);
+ }
+
+ // If the ConnectionInfo has no more connections, remove the ConnectionInfo
+ if (info.CountConnections == 0)
+ {
+ ownerToConnectionMap.Remove(disconnectParams.OwnerUri);
+ }
+
+ // Handle Telemetry disconnect events if we are disconnecting the default connection
+ if (disconnectParams.Type == null || disconnectParams.Type == ConnectionType.Default)
+ {
+ HandleDisconnectTelemetry(info);
+ InvokeOnDisconnectionActivities(info);
+ }
+
+ // Return true upon success
+ return true;
+ }
+
+ ///
+ /// Cancel connections associated with the given ownerUri.
+ /// If connectionType is not null, cancel the connection with the given connectionType
+ /// If connectionType is null, cancel all pending connections associated with ownerUri.
+ ///
+ /// true if a single pending connection associated with the non-null connectionType was
+ /// found and cancelled, false otherwise
+ private bool CancelConnections(string ownerUri, string connectionType)
+ {
+ // Cancel the connection of the given type
+ if (connectionType != null)
+ {
+ // If we are trying to disconnect a specific connection and it was just cancelled,
+ // this will return true
+ return CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = connectionType });
+ }
+
+ // Cancel all pending connections
+ foreach (var entry in cancelTupleToCancellationTokenSourceMap)
+ {
+ string entryConnectionUri = entry.Key.OwnerUri;
+ string entryConnectionType = entry.Key.Type;
+ if (ownerUri.Equals(entryConnectionUri))
+ {
+ CancelConnect(new CancelConnectParams() { OwnerUri = ownerUri, Type = entryConnectionType });
+ }
+ }
+
+ return false;
+ }
+
+ ///
+ /// Closes DbConnections associated with the given ConnectionInfo.
+ /// If connectionType is not null, closes the DbConnection with the type given by connectionType.
+ /// If connectionType is null, closes all DbConnections.
+ ///
+ /// true if connections were found and attempted to be closed,
+ /// false if no connections were found
+ private bool CloseConnections(ConnectionInfo connectionInfo, string connectionType)
+ {
+ ICollection connectionsToDisconnect = new List();
+ if (connectionType == null)
+ {
+ connectionsToDisconnect = connectionInfo.AllConnections;
+ }
+ else
+ {
+ // Make sure there is an existing connection of this type
+ DbConnection connection;
+ if (!connectionInfo.TryGetConnection(connectionType, out connection))
+ {
+ return false;
+ }
+ connectionsToDisconnect.Add(connection);
+ }
+
+ if (connectionsToDisconnect.Count == 0)
+ {
+ return false;
+ }
+
+ foreach (DbConnection connection in connectionsToDisconnect)
+ {
+ try
+ {
+ connection.Close();
+ }
+ catch (Exception)
+ {
+ // Ignore
+ }
+ }
+
+ return true;
+ }
+
+ ///
+ /// List all databases on the server specified
+ ///
+ public ListDatabasesResponse ListDatabases(ListDatabasesParams listDatabasesParams)
+ {
+ // Verify parameters
+ var owner = listDatabasesParams.OwnerUri;
+ if (string.IsNullOrEmpty(owner))
+ {
+ throw new ArgumentException(SR.ConnectionServiceListDbErrorNullOwnerUri);
+ }
+
+ // Use the existing connection as a base for the search
+ ConnectionInfo info;
+ if (!TryFindConnection(owner, out info))
+ {
+ throw new Exception(SR.ConnectionServiceListDbErrorNotConnected(owner));
+ }
+ ConnectionDetails connectionDetails = info.ConnectionDetails.Clone();
+
+ // Connect to master and query sys.databases
+ connectionDetails.DatabaseName = "master";
+ var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails));
+ connection.Open();
+
+ List results = new List();
+ var systemDatabases = new[] {"master", "model", "msdb", "tempdb"};
+ using (DbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = "SELECT name FROM sys.databases ORDER BY name ASC";
+ command.CommandTimeout = 15;
+ command.CommandType = CommandType.Text;
+
+ using (var reader = command.ExecuteReader())
+ {
+ while (reader.Read())
+ {
+ results.Add(reader[0].ToString());
+ }
+ }
+ }
+
+ // Put system databases at the top of the list
+ results =
+ results.Where(s => systemDatabases.Any(s.Equals)).Concat(
+ results.Where(s => systemDatabases.All(x => !s.Equals(x)))).ToList();
+
+ connection.Close();
+
+ ListDatabasesResponse response = new ListDatabasesResponse();
+ response.DatabaseNames = results.ToArray();
+
+ return response;
+ }
+
+ public void InitializeService(IProtocolEndpoint serviceHost)
+ {
+ this.ServiceHost = serviceHost;
+
+ // Register request and event handlers with the Service Host
+ serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest);
+ serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest);
+ serviceHost.SetRequestHandler(DisconnectRequest.Type, HandleDisconnectRequest);
+ serviceHost.SetRequestHandler(ListDatabasesRequest.Type, HandleListDatabasesRequest);
+
+ // Register the configuration update handler
+ WorkspaceService.Instance.RegisterConfigChangeCallback(HandleDidChangeConfigurationNotification);
+ }
+
+ ///
+ /// Add a new method to be called when the onconnection request is submitted
+ ///
+ ///
+ public void RegisterOnConnectionTask(OnConnectionHandler activity)
+ {
+ onConnectionActivities.Add(activity);
+ }
+
+ ///
+ /// Add a new method to be called when the ondisconnect request is submitted
+ ///
+ public void RegisterOnDisconnectTask(OnDisconnectHandler activity)
+ {
+ onDisconnectActivities.Add(activity);
+ }
+
+ ///
+ /// Handle new connection requests
+ ///
+ ///
+ ///
+ ///
+ protected async Task HandleConnectRequest(
+ ConnectParams connectParams,
+ RequestContext requestContext)
+ {
+ Logger.Write(LogLevel.Verbose, "HandleConnectRequest");
+
+ try
+ {
+ RunConnectRequestHandlerTask(connectParams);
+ await requestContext.SendResult(true);
+ }
+ catch
+ {
+ await requestContext.SendResult(false);
+ }
+ }
+
+ private void RunConnectRequestHandlerTask(ConnectParams connectParams)
+ {
+ // create a task to connect asynchronously so that other requests are not blocked in the meantime
+ Task.Run(async () =>
+ {
+ try
+ {
+ // result is null if the ConnectParams was successfully validated
+ ConnectionCompleteParams result = ValidateConnectParams(connectParams);
+ if (result != null)
+ {
+ await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
+ return;
+ }
+
+ // open connection based on request details
+ result = await Connect(connectParams);
+ await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
+ }
+ catch (Exception ex)
+ {
+ ConnectionCompleteParams result = new ConnectionCompleteParams()
+ {
+ Messages = ex.ToString()
+ };
+ await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result);
+ }
+ });
+ }
+
+ ///
+ /// Handle cancel connect requests
+ ///
+ protected async Task HandleCancelConnectRequest(
+ CancelConnectParams cancelParams,
+ RequestContext requestContext)
+ {
+ Logger.Write(LogLevel.Verbose, "HandleCancelConnectRequest");
+
+ try
+ {
+ bool result = CancelConnect(cancelParams);
+ await requestContext.SendResult(result);
+ }
+ catch(Exception ex)
+ {
+ await requestContext.SendError(ex.ToString());
+ }
+ }
+
+ ///
+ /// Handle disconnect requests
+ ///
+ protected async Task HandleDisconnectRequest(
+ DisconnectParams disconnectParams,
+ RequestContext requestContext)
+ {
+ Logger.Write(LogLevel.Verbose, "HandleDisconnectRequest");
+
+ try
+ {
+ bool result = Instance.Disconnect(disconnectParams);
+ await requestContext.SendResult(result);
+ }
+ catch(Exception ex)
+ {
+ await requestContext.SendError(ex.ToString());
+ }
+
+ }
+
+ ///
+ /// Handle requests to list databases on the current server
+ ///
+ protected async Task HandleListDatabasesRequest(
+ ListDatabasesParams listDatabasesParams,
+ RequestContext requestContext)
+ {
+ Logger.Write(LogLevel.Verbose, "ListDatabasesRequest");
+
+ try
+ {
+ ListDatabasesResponse result = Instance.ListDatabases(listDatabasesParams);
+ await requestContext.SendResult(result);
+ }
+ catch(Exception ex)
+ {
+ await requestContext.SendError(ex.ToString());
+ }
+ }
+
+ public Task HandleDidChangeConfigurationNotification(
+ SqlToolsSettings newSettings,
+ SqlToolsSettings oldSettings,
+ EventContext eventContext)
+ {
+ return Task.FromResult(true);
+ }
+
+ ///
+ /// Build a connection string from a connection details instance
+ ///
+ ///
+ public static string BuildConnectionString(ConnectionDetails connectionDetails)
+ {
+ SqlConnectionStringBuilder connectionBuilder;
+
+ // If connectionDetails has a connection string already, just validate and return it
+ if (!string.IsNullOrEmpty(connectionDetails.ConnectionString))
+ {
+ connectionBuilder = new SqlConnectionStringBuilder(connectionDetails.ConnectionString);
+ return connectionBuilder.ToString();
+ }
+
+ connectionBuilder = new SqlConnectionStringBuilder
+ {
+ ["Data Source"] = connectionDetails.ServerName,
+ ["User Id"] = connectionDetails.UserName,
+ ["Password"] = connectionDetails.Password
+ };
+
+ // Check for any optional parameters
+ if (!string.IsNullOrEmpty(connectionDetails.DatabaseName))
+ {
+ connectionBuilder["Initial Catalog"] = connectionDetails.DatabaseName;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.AuthenticationType))
+ {
+ switch(connectionDetails.AuthenticationType)
+ {
+ case "Integrated":
+ connectionBuilder.IntegratedSecurity = true;
+ break;
+ case "SqlLogin":
+ break;
+ default:
+ throw new ArgumentException(SR.ConnectionServiceConnStringInvalidAuthType(connectionDetails.AuthenticationType));
+ }
+ }
+ if (connectionDetails.Encrypt.HasValue)
+ {
+ connectionBuilder.Encrypt = connectionDetails.Encrypt.Value;
+ }
+ if (connectionDetails.TrustServerCertificate.HasValue)
+ {
+ connectionBuilder.TrustServerCertificate = connectionDetails.TrustServerCertificate.Value;
+ }
+ if (connectionDetails.PersistSecurityInfo.HasValue)
+ {
+ connectionBuilder.PersistSecurityInfo = connectionDetails.PersistSecurityInfo.Value;
+ }
+ if (connectionDetails.ConnectTimeout.HasValue)
+ {
+ connectionBuilder.ConnectTimeout = connectionDetails.ConnectTimeout.Value;
+ }
+ if (connectionDetails.ConnectRetryCount.HasValue)
+ {
+ connectionBuilder.ConnectRetryCount = connectionDetails.ConnectRetryCount.Value;
+ }
+ if (connectionDetails.ConnectRetryInterval.HasValue)
+ {
+ connectionBuilder.ConnectRetryInterval = connectionDetails.ConnectRetryInterval.Value;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.ApplicationName))
+ {
+ connectionBuilder.ApplicationName = connectionDetails.ApplicationName;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.WorkstationId))
+ {
+ connectionBuilder.WorkstationID = connectionDetails.WorkstationId;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.ApplicationIntent))
+ {
+ ApplicationIntent intent;
+ switch (connectionDetails.ApplicationIntent)
+ {
+ case "ReadOnly":
+ intent = ApplicationIntent.ReadOnly;
+ break;
+ case "ReadWrite":
+ intent = ApplicationIntent.ReadWrite;
+ break;
+ default:
+ throw new ArgumentException(SR.ConnectionServiceConnStringInvalidIntent(connectionDetails.ApplicationIntent));
+ }
+ connectionBuilder.ApplicationIntent = intent;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.CurrentLanguage))
+ {
+ connectionBuilder.CurrentLanguage = connectionDetails.CurrentLanguage;
+ }
+ if (connectionDetails.Pooling.HasValue)
+ {
+ connectionBuilder.Pooling = connectionDetails.Pooling.Value;
+ }
+ if (connectionDetails.MaxPoolSize.HasValue)
+ {
+ connectionBuilder.MaxPoolSize = connectionDetails.MaxPoolSize.Value;
+ }
+ if (connectionDetails.MinPoolSize.HasValue)
+ {
+ connectionBuilder.MinPoolSize = connectionDetails.MinPoolSize.Value;
+ }
+ if (connectionDetails.LoadBalanceTimeout.HasValue)
+ {
+ connectionBuilder.LoadBalanceTimeout = connectionDetails.LoadBalanceTimeout.Value;
+ }
+ if (connectionDetails.Replication.HasValue)
+ {
+ connectionBuilder.Replication = connectionDetails.Replication.Value;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.AttachDbFilename))
+ {
+ connectionBuilder.AttachDBFilename = connectionDetails.AttachDbFilename;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.FailoverPartner))
+ {
+ connectionBuilder.FailoverPartner = connectionDetails.FailoverPartner;
+ }
+ if (connectionDetails.MultiSubnetFailover.HasValue)
+ {
+ connectionBuilder.MultiSubnetFailover = connectionDetails.MultiSubnetFailover.Value;
+ }
+ if (connectionDetails.MultipleActiveResultSets.HasValue)
+ {
+ connectionBuilder.MultipleActiveResultSets = connectionDetails.MultipleActiveResultSets.Value;
+ }
+ if (connectionDetails.PacketSize.HasValue)
+ {
+ connectionBuilder.PacketSize = connectionDetails.PacketSize.Value;
+ }
+ if (!string.IsNullOrEmpty(connectionDetails.TypeSystemVersion))
+ {
+ connectionBuilder.TypeSystemVersion = connectionDetails.TypeSystemVersion;
+ }
+
+ return connectionBuilder.ToString();
+ }
+
+ ///
+ /// Change the database context of a connection.
+ ///
+ /// URI of the owner of the connection
+ /// Name of the database to change the connection to
+ public void ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName)
+ {
+ ConnectionInfo info;
+ if (TryFindConnection(ownerUri, out info))
+ {
+ try
+ {
+ foreach (DbConnection connection in info.AllConnections)
+ {
+ if (connection.State == ConnectionState.Open)
+ {
+ connection.ChangeDatabase(newDatabaseName);
+ }
+ }
+
+ info.ConnectionDetails.DatabaseName = newDatabaseName;
+
+ // Fire a connection changed event
+ ConnectionChangedParams parameters = new ConnectionChangedParams();
+ ConnectionSummary summary = info.ConnectionDetails;
+ parameters.Connection = summary.Clone();
+ parameters.OwnerUri = ownerUri;
+ ServiceHost.SendEvent(ConnectionChangedNotification.Type, parameters);
+ }
+ catch (Exception e)
+ {
+ Logger.Write(
+ LogLevel.Error,
+ string.Format(
+ "Exception caught while trying to change database context to [{0}] for OwnerUri [{1}]. Exception:{2}",
+ newDatabaseName,
+ ownerUri,
+ e.ToString())
+ );
+ }
+ }
+ }
+
+ ///
+ /// Invokes the initial on-connect activities if the provided ConnectParams represents the default
+ /// connection.
+ ///
+ private void InvokeOnConnectionActivities(ConnectionInfo connectionInfo, ConnectParams connectParams)
+ {
+ if (connectParams.Type != ConnectionType.Default)
+ {
+ return;
+ }
+
+ foreach (var activity in this.onConnectionActivities)
+ {
+ // not awaiting here to allow handlers to run in the background
+ activity(connectionInfo);
+ }
+ }
+
+ ///
+ /// Invokes the final on-disconnect activities if the provided DisconnectParams represents the default
+ /// connection or is null - representing that all connections are being disconnected.
+ ///
+ private void InvokeOnDisconnectionActivities(ConnectionInfo connectionInfo)
+ {
+ foreach (var activity in this.onDisconnectActivities)
+ {
+ activity(connectionInfo.ConnectionDetails, connectionInfo.OwnerUri);
+ }
+ }
+
+ ///
+ /// Handles the Telemetry events that occur upon disconnect.
+ ///
+ ///
+ private void HandleDisconnectTelemetry(ConnectionInfo connectionInfo)
+ {
+ if (ServiceHost != null)
+ {
+ try
+ {
+ // Send a telemetry notification for intellisense performance metrics
+ ServiceHost.SendEvent(TelemetryNotification.Type, new TelemetryParams()
+ {
+ Params = new TelemetryProperties
+ {
+ Properties = new Dictionary
+ {
+ { TelemetryPropertyNames.IsAzure, connectionInfo.IsAzure.ToOneOrZeroString() }
+ },
+ EventName = TelemetryEventNames.IntellisenseQuantile,
+ Measures = connectionInfo.IntellisenseMetrics.Quantile
+ }
+ });
+ }
+ catch (Exception ex)
+ {
+ Logger.Write(LogLevel.Verbose, "Could not send Connection telemetry event " + ex.ToString());
+ }
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
index 58d38ee4..3333ce00 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParams.cs
@@ -14,7 +14,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// A URI identifying the owner of the connection. This will most commonly be a file in the workspace
/// or a virtual file representing an object in a database.
///
- public string OwnerUri { get; set; }
+ public string OwnerUri { get; set; }
+
///
/// Contains the required parameters to initialize a connection to a database.
/// A connection will identified by its server name, database name and user name.
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
index 7b581658..6b191dde 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectParamsExtensions.cs
@@ -24,6 +24,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
errorMessage = SR.ConnectionParamsValidateNullConnection;
}
+ else if (!string.IsNullOrEmpty(parameters.Connection.ConnectionString))
+ {
+ // Do not check other connection parameters if a connection string is present
+ return string.IsNullOrEmpty(errorMessage);
+ }
else if (string.IsNullOrEmpty(parameters.Connection.ServerName))
{
errorMessage = SR.ConnectionParamsValidateNullServerName;
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
index 68dfa51f..57312cbf 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
@@ -5,7 +5,7 @@
using System;
using System.Collections.Generic;
-using System.Globalization;
+using System.Globalization;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
@@ -443,6 +443,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
}
}
+ ///
+ /// Gets or sets a string value to be used as the connection string. If given, all other options will be ignored.
+ ///
+ public string ConnectionString
+ {
+ get
+ {
+ return GetOptionValue("connectionString");
+ }
+
+ set
+ {
+ SetOptionValue("connectionString", value);
+ }
+ }
+
private T GetOptionValue(string name)
{
T result = default(T);
@@ -485,33 +501,33 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
Options.Add(name, value);
}
- }
-
- public bool IsComparableTo(ConnectionDetails other)
- {
- if (other == null)
- {
- return false;
- }
-
- if (!string.Equals(ServerName, other.ServerName)
- || !string.Equals(AuthenticationType, other.AuthenticationType)
- || !string.Equals(UserName, other.UserName))
- {
- return false;
- }
-
- // For database name, only compare if neither is empty. This is important
- // Since it allows for handling of connections to the default database, but is
- // not a 100% accurate heuristic.
- if (!string.IsNullOrEmpty(DatabaseName)
- && !string.IsNullOrEmpty(other.DatabaseName)
- && !string.Equals(DatabaseName, other.DatabaseName))
- {
- return false;
- }
-
- return true;
- }
+ }
+
+ public bool IsComparableTo(ConnectionDetails other)
+ {
+ if (other == null)
+ {
+ return false;
+ }
+
+ if (ServerName != other.ServerName
+ || AuthenticationType != other.AuthenticationType
+ || UserName != other.UserName)
+ {
+ return false;
+ }
+
+ // For database name, only compare if neither is empty. This is important
+ // Since it allows for handling of connections to the default database, but is
+ // not a 100% accurate heuristic.
+ if (!string.IsNullOrEmpty(DatabaseName)
+ && !string.IsNullOrEmpty(other.DatabaseName)
+ && DatabaseName != other.DatabaseName)
+ {
+ return false;
+ }
+
+ return true;
+ }
}
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
index 690f7577..9e279b60 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
@@ -325,10 +325,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
.Returns((string connString) =>
{
dummySqlConnection.ConnectionString = connString;
- SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
-
- // Database name is respected. Follow heuristic where empty DB name really means Master
- var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog;
+ SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
+
+ // Database name is respected. Follow heuristic where empty DB name really means Master
+ var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog;
dummySqlConnection.SetDatabase(dbName);
return dummySqlConnection;
});
@@ -1210,5 +1210,37 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
await Assert.ThrowsAsync(
() => service.GetOrOpenConnection(TestObjects.ScriptUri, ConnectionType.Query));
}
+
+ [Fact]
+ public async Task ConnectionWithConnectionStringSucceeds()
+ {
+ var connectionParameters = TestObjects.GetTestConnectionParams(true);
+ var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters);
+
+ Assert.NotEmpty(connectionResult.ConnectionId);
+ }
+
+ [Fact]
+ public async Task ConnectionWithBadConnectionStringFails()
+ {
+ var connectionParameters = TestObjects.GetTestConnectionParams(true);
+ connectionParameters.Connection.ConnectionString = "thisisnotavalidconnectionstring";
+ var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters);
+
+ Assert.NotEmpty(connectionResult.ErrorMessage);
+ }
+
+ [Fact]
+ public async Task ConnectionWithConnectionStringOverridesParameters()
+ {
+ var connectionParameters = TestObjects.GetTestConnectionParams();
+ connectionParameters.Connection.ServerName = "overriddenServerName";
+ var connectionString = TestObjects.GetTestConnectionParams(true).Connection.ConnectionString;
+ connectionParameters.Connection.ConnectionString = connectionString;
+
+ // Connect and verify that the server name has been overridden
+ var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters);
+ Assert.NotEqual(connectionParameters.Connection.ServerName, connectionResult.ConnectionSummary.ServerName);
+ }
}
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs
index 240877e2..c9cf078e 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs
@@ -44,12 +44,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
GetTestConnectionDetails());
}
- public static ConnectParams GetTestConnectionParams()
+ public static ConnectParams GetTestConnectionParams(bool useConnectionString = false)
{
return new ConnectParams()
{
OwnerUri = ScriptUri,
- Connection = GetTestConnectionDetails()
+ Connection = GetTestConnectionDetails(useConnectionString)
};
}
@@ -66,8 +66,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
ServerEdition = "Developer Edition",
ServerLevel = ""
};
- }
-
+ }
+
///
/// Creates a test sql connection factory instance
///
@@ -80,8 +80,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
///
/// Creates a test connection details object
///
- public static ConnectionDetails GetTestConnectionDetails()
+ public static ConnectionDetails GetTestConnectionDetails(bool useConnectionString = false)
{
+ if (useConnectionString)
+ {
+ return new ConnectionDetails()
+ {
+ ConnectionString = "User ID=user;PWD=password;Database=databaseName;Server=serverName"
+ };
+ }
+
return new ConnectionDetails()
{
UserName = "user",
@@ -214,9 +222,9 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
}
public override string ConnectionString { get; set; }
- public override string Database
+ public override string Database
{
- get { return _database; }
+ get { return _database; }
}
public override ConnectionState State { get; }
@@ -233,13 +241,13 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
// No Op
}
- ///
- /// Test helper method to set the database value
- ///
+ ///
+ /// Test helper method to set the database value
+ ///
///
- public void SetDatabase(string database)
- {
- this._database = database;
+ public void SetDatabase(string database)
+ {
+ this._database = database;
}
}