diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs index 2d82f3ce..457fa985 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -99,6 +99,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } + public bool HasConnectionType(string connectionType) + { + connectionType = connectionType ?? ConnectionType.Default; + return ConnectionTypeToConnectionMap.ContainsKey(connectionType); + } + /// /// The count of DbConnectioninstances held by this ConnectionInfo /// @@ -155,7 +161,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection DbConnection connection; ConnectionTypeToConnectionMap.TryRemove(type, out connection); } - } - + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index ae846247..9901952b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -188,44 +188,78 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // If there is no ConnectionInfo in the map, create a new ConnectionInfo, // but wait until later when we are connected to add it to the map. ConnectionInfo connectionInfo; + bool connectionChanged = false; if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo)) { connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); } + else if (IsConnectionChanged(connectionParams, connectionInfo)) + { + // We are actively changing the connection information for this connection. We must disconnect + // all active connections, since it represents a full context change + connectionChanged = true; + } - // Resolve if it is an existing connection - // Disconnect active connection if the URI is already connected for this connection type + DisconnectExistingConnectionIfNeeded(connectionParams, connectionInfo, disconnectAll: connectionChanged); + + if (connectionChanged) + { + connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); + } + + // Try to open a connection with the given ConnectParams + ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams); + if (response != null) + { + return response; + } + + // If this is the first connection for this URI, add the ConnectionInfo to the map + bool addToMap = connectionChanged || !ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri); + if (addToMap) + { + ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; + } + + // Return information about the connected SQL Server instance + ConnectionCompleteParams completeParams = GetConnectionCompleteParams(connectionParams.Type, connectionInfo); + // Invoke callback notifications + InvokeOnConnectionActivities(connectionInfo, connectionParams); + + return completeParams; + } + + private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo) + { + if (connectionInfo.HasConnectionType(connectionParams.Type) + && !connectionInfo.ConnectionDetails.IsComparableTo(connectionParams.Connection)) + { + return true; + } + return false; + } + + private bool IsDefaultConnectionType(string connectionType) + { + return string.IsNullOrEmpty(connectionType) || ConnectionType.Default.Equals(connectionType, StringComparison.CurrentCultureIgnoreCase); + } + + private void DisconnectExistingConnectionIfNeeded(ConnectParams connectionParams, ConnectionInfo connectionInfo, bool disconnectAll) + { + // Resolve if it is an existing connection + // Disconnect active connection if the URI is already connected for this connection type DbConnection existingConnection; if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection)) { var disconnectParams = new DisconnectParams() { - OwnerUri = connectionParams.OwnerUri, - Type = connectionParams.Type + OwnerUri = connectionParams.OwnerUri, + Type = disconnectAll ? null : connectionParams.Type }; Disconnect(disconnectParams); - } - - // Try to open a connection with the given ConnectParams - ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams); - if (response != null) - { - return response; - } - - // If this is the first connection for this URI, add the ConnectionInfo to the map - if (!ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri)) - { - ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; - } - - // Invoke callback notifications - InvokeOnConnectionActivities(connectionInfo, connectionParams); - - // Return information about the connected SQL Server instance - return GetConnectionCompleteParams(connectionParams.Type, connectionInfo); - } - + } + } + /// /// Creates a ConnectionCompleteParams as a response to a successful connection. /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo. @@ -304,46 +338,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connectionInfo.AddConnection(connectionParams.Type, connection); // Add a cancellation token source so that the connection OpenAsync() can be cancelled - using (source = new CancellationTokenSource()) + source = new CancellationTokenSource(); + // Locking here to perform two operations as one atomic operation + lock (cancellationTokenSourceLock) { - // Locking here to perform two operations as one atomic operation - lock (cancellationTokenSourceLock) + // If the URI is currently connecting from a different request, cancel it before we try to connect + CancellationTokenSource currentSource; + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource)) { - // If the URI is currently connecting from a different request, cancel it before we try to connect - CancellationTokenSource currentSource; - if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource)) - { - currentSource.Cancel(); - } - cancelTupleToCancellationTokenSourceMap[cancelKey] = source; + currentSource.Cancel(); } - - // Create a task to handle cancellation requests - var cancellationTask = Task.Run(() => - { - source.Token.WaitHandle.WaitOne(); - try - { - source.Token.ThrowIfCancellationRequested(); - } - catch (ObjectDisposedException) - { - // If ObjectDisposedException was thrown, then execution has already exited the - // "using" statment and source was disposed, meaning that the openTask completed - // successfully. This results in a ObjectDisposedException when trying to access - // source.Token and should be ignored. - } - }); - - var openTask = Task.Run(async () => - { - await connection.OpenAsync(source.Token); - }); - - // Open the connection - await Task.WhenAny(openTask, cancellationTask).Unwrap(); - source.Cancel(); + cancelTupleToCancellationTokenSourceMap[cancelKey] = source; } + + // Open the connection + await connection.OpenAsync(source.Token); } catch (SqlException ex) { @@ -376,6 +385,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue); } + source?.Dispose(); } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs index bdce4fcb..68dfa51f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -5,7 +5,7 @@ using System; using System.Collections.Generic; -using System.Globalization; +using System.Globalization; using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts @@ -485,6 +485,33 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { Options.Add(name, value); } - } + } + + public bool IsComparableTo(ConnectionDetails other) + { + if (other == null) + { + return false; + } + + if (!string.Equals(ServerName, other.ServerName) + || !string.Equals(AuthenticationType, other.AuthenticationType) + || !string.Equals(UserName, other.UserName)) + { + return false; + } + + // For database name, only compare if neither is empty. This is important + // Since it allows for handling of connections to the default database, but is + // not a 100% accurate heuristic. + if (!string.IsNullOrEmpty(DatabaseName) + && !string.IsNullOrEmpty(other.DatabaseName) + && !string.Equals(DatabaseName, other.DatabaseName)) + { + return false; + } + + return true; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 8efe3984..690f7577 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -314,9 +314,37 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection { bool callbackInvoked = false; - // first connect string ownerUri = "file://my/sample/file.sql"; - var connectionService = TestObjects.GetTestConnectionService(); + const string masterDbName = "master"; + const string otherDbName = "other"; + // Given a connection that returns the database name + var dummySqlConnection = new TestSqlConnection(null); + + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns((string connString) => + { + dummySqlConnection.ConnectionString = connString; + SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); + + // Database name is respected. Follow heuristic where empty DB name really means Master + var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; + dummySqlConnection.SetDatabase(dbName); + return dummySqlConnection; + }); + + var connectionService = new ConnectionService(mockFactory.Object); + + // register disconnect callback + connectionService.RegisterOnDisconnectTask( + (result, uri) => { + callbackInvoked = true; + Assert.True(uri.Equals(ownerUri)); + return Task.FromResult(true); + } + ); + + // When I connect to default var connectionResult = await connectionService .Connect(new ConnectParams() @@ -325,32 +353,27 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Connection = TestObjects.GetTestConnectionDetails() }); - // verify that we are connected + // Then I expect to be connected to master Assert.NotEmpty(connectionResult.ConnectionId); - // register disconnect callback - connectionService.RegisterOnDisconnectTask( - (result, uri) => { - callbackInvoked = true; - Assert.True(uri.Equals(ownerUri)); - return Task.FromResult(true); - } - ); - - // send annother connect request + // And when I then connect to another DB + var updatedConnectionDetails = TestObjects.GetTestConnectionDetails(); + updatedConnectionDetails.DatabaseName = otherDbName; connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, - Connection = TestObjects.GetTestConnectionDetails() + Connection = updatedConnectionDetails }); + // Then I expect to be disconnected from master, and connected to the new DB // verify that the event was fired (we disconnected first before connecting) Assert.True(callbackInvoked); // verify that we connected again Assert.NotEmpty(connectionResult.ConnectionId); + Assert.Equal(otherDbName, connectionResult.ConnectionSummary.DatabaseName); } /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs index ef405469..a8d5b429 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs @@ -161,6 +161,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility /// public class TestSqlConnection : DbConnection { + private string _database; + internal TestSqlConnection(TestResultSet[] data) { Data = data; @@ -188,7 +190,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility } public override string ConnectionString { get; set; } - public override string Database { get; } + public override string Database + { + get { return _database; } + } + public override ConnectionState State { get; } public override string DataSource { get; } public override string ServerVersion { get; } @@ -202,6 +208,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility { // No Op } + + /// + /// Test helper method to set the database value + /// + /// + public void SetDatabase(string database) + { + this._database = database; + } } ///