From f3bf330da6548ddb9c1c0e8f29988e8dbdeef235 Mon Sep 17 00:00:00 2001 From: Kevin Cunnane Date: Thu, 6 Apr 2017 11:25:59 -0700 Subject: [PATCH] Connect with different properties should actually change context (#307) * Connect with different properties should actually change context - Up to now, calling Connect for a previously-connected URI would disconnect, then reconnect ot the original (not new) target. WIth these changes we handle changes to database name or other key properties by updating the ConnectionInfo and connecting to the new target - Some interesting scenarios are raised by our API, notably that an empty database name maps to the default DB (which we know nothing about). This limits the new feature such that only if the DB Name is specified, we'll change the connection. Hence 2 calls to an empty DB will not result in a DB change. Additional changes: - After discussion with Ben, we're simplifying the cancellation logic. He had made changes to support this, so the main update is that we dispose the token in the final block after its last use (hence avoiding a disposed exception) and clean up the number of Waits required since we already have async cancellation support - Factored some logic such that the OnConnection callback isn't invoked until after we've updated the database name in the GetConnectionCompleteParams method. Again, this supports reporting the actual DB name instead of leaving it blank for default DB requests. * PR comment fixes --- .../Connection/ConnectionInfo.cs | 9 +- .../Connection/ConnectionService.cs | 134 ++++++++++-------- .../Connection/Contracts/ConnectionDetails.cs | 31 +++- .../Connection/ConnectionServiceTests.cs | 51 +++++-- .../Utility/TestObjects.cs | 17 ++- 5 files changed, 161 insertions(+), 81 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs index 2d82f3ce..457fa985 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs @@ -99,6 +99,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection } } + public bool HasConnectionType(string connectionType) + { + connectionType = connectionType ?? ConnectionType.Default; + return ConnectionTypeToConnectionMap.ContainsKey(connectionType); + } + /// /// The count of DbConnectioninstances held by this ConnectionInfo /// @@ -155,7 +161,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection DbConnection connection; ConnectionTypeToConnectionMap.TryRemove(type, out connection); } - } - + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index ae846247..9901952b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -188,44 +188,78 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // If there is no ConnectionInfo in the map, create a new ConnectionInfo, // but wait until later when we are connected to add it to the map. ConnectionInfo connectionInfo; + bool connectionChanged = false; if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo)) { connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); } + else if (IsConnectionChanged(connectionParams, connectionInfo)) + { + // We are actively changing the connection information for this connection. We must disconnect + // all active connections, since it represents a full context change + connectionChanged = true; + } - // Resolve if it is an existing connection - // Disconnect active connection if the URI is already connected for this connection type + DisconnectExistingConnectionIfNeeded(connectionParams, connectionInfo, disconnectAll: connectionChanged); + + if (connectionChanged) + { + connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); + } + + // Try to open a connection with the given ConnectParams + ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams); + if (response != null) + { + return response; + } + + // If this is the first connection for this URI, add the ConnectionInfo to the map + bool addToMap = connectionChanged || !ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri); + if (addToMap) + { + ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; + } + + // Return information about the connected SQL Server instance + ConnectionCompleteParams completeParams = GetConnectionCompleteParams(connectionParams.Type, connectionInfo); + // Invoke callback notifications + InvokeOnConnectionActivities(connectionInfo, connectionParams); + + return completeParams; + } + + private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo) + { + if (connectionInfo.HasConnectionType(connectionParams.Type) + && !connectionInfo.ConnectionDetails.IsComparableTo(connectionParams.Connection)) + { + return true; + } + return false; + } + + private bool IsDefaultConnectionType(string connectionType) + { + return string.IsNullOrEmpty(connectionType) || ConnectionType.Default.Equals(connectionType, StringComparison.CurrentCultureIgnoreCase); + } + + private void DisconnectExistingConnectionIfNeeded(ConnectParams connectionParams, ConnectionInfo connectionInfo, bool disconnectAll) + { + // Resolve if it is an existing connection + // Disconnect active connection if the URI is already connected for this connection type DbConnection existingConnection; if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection)) { var disconnectParams = new DisconnectParams() { - OwnerUri = connectionParams.OwnerUri, - Type = connectionParams.Type + OwnerUri = connectionParams.OwnerUri, + Type = disconnectAll ? null : connectionParams.Type }; Disconnect(disconnectParams); - } - - // Try to open a connection with the given ConnectParams - ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams); - if (response != null) - { - return response; - } - - // If this is the first connection for this URI, add the ConnectionInfo to the map - if (!ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri)) - { - ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; - } - - // Invoke callback notifications - InvokeOnConnectionActivities(connectionInfo, connectionParams); - - // Return information about the connected SQL Server instance - return GetConnectionCompleteParams(connectionParams.Type, connectionInfo); - } - + } + } + /// /// Creates a ConnectionCompleteParams as a response to a successful connection. /// Also sets the DatabaseName and IsAzure properties of ConnectionInfo. @@ -304,46 +338,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connectionInfo.AddConnection(connectionParams.Type, connection); // Add a cancellation token source so that the connection OpenAsync() can be cancelled - using (source = new CancellationTokenSource()) + source = new CancellationTokenSource(); + // Locking here to perform two operations as one atomic operation + lock (cancellationTokenSourceLock) { - // Locking here to perform two operations as one atomic operation - lock (cancellationTokenSourceLock) + // If the URI is currently connecting from a different request, cancel it before we try to connect + CancellationTokenSource currentSource; + if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource)) { - // If the URI is currently connecting from a different request, cancel it before we try to connect - CancellationTokenSource currentSource; - if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource)) - { - currentSource.Cancel(); - } - cancelTupleToCancellationTokenSourceMap[cancelKey] = source; + currentSource.Cancel(); } - - // Create a task to handle cancellation requests - var cancellationTask = Task.Run(() => - { - source.Token.WaitHandle.WaitOne(); - try - { - source.Token.ThrowIfCancellationRequested(); - } - catch (ObjectDisposedException) - { - // If ObjectDisposedException was thrown, then execution has already exited the - // "using" statment and source was disposed, meaning that the openTask completed - // successfully. This results in a ObjectDisposedException when trying to access - // source.Token and should be ignored. - } - }); - - var openTask = Task.Run(async () => - { - await connection.OpenAsync(source.Token); - }); - - // Open the connection - await Task.WhenAny(openTask, cancellationTask).Unwrap(); - source.Cancel(); + cancelTupleToCancellationTokenSourceMap[cancelKey] = source; } + + // Open the connection + await connection.OpenAsync(source.Token); } catch (SqlException ex) { @@ -376,6 +385,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue); } + source?.Dispose(); } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs index bdce4fcb..68dfa51f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs @@ -5,7 +5,7 @@ using System; using System.Collections.Generic; -using System.Globalization; +using System.Globalization; using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts @@ -485,6 +485,33 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { Options.Add(name, value); } - } + } + + public bool IsComparableTo(ConnectionDetails other) + { + if (other == null) + { + return false; + } + + if (!string.Equals(ServerName, other.ServerName) + || !string.Equals(AuthenticationType, other.AuthenticationType) + || !string.Equals(UserName, other.UserName)) + { + return false; + } + + // For database name, only compare if neither is empty. This is important + // Since it allows for handling of connections to the default database, but is + // not a 100% accurate heuristic. + if (!string.IsNullOrEmpty(DatabaseName) + && !string.IsNullOrEmpty(other.DatabaseName) + && !string.Equals(DatabaseName, other.DatabaseName)) + { + return false; + } + + return true; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 8efe3984..690f7577 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -314,9 +314,37 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection { bool callbackInvoked = false; - // first connect string ownerUri = "file://my/sample/file.sql"; - var connectionService = TestObjects.GetTestConnectionService(); + const string masterDbName = "master"; + const string otherDbName = "other"; + // Given a connection that returns the database name + var dummySqlConnection = new TestSqlConnection(null); + + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns((string connString) => + { + dummySqlConnection.ConnectionString = connString; + SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); + + // Database name is respected. Follow heuristic where empty DB name really means Master + var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; + dummySqlConnection.SetDatabase(dbName); + return dummySqlConnection; + }); + + var connectionService = new ConnectionService(mockFactory.Object); + + // register disconnect callback + connectionService.RegisterOnDisconnectTask( + (result, uri) => { + callbackInvoked = true; + Assert.True(uri.Equals(ownerUri)); + return Task.FromResult(true); + } + ); + + // When I connect to default var connectionResult = await connectionService .Connect(new ConnectParams() @@ -325,32 +353,27 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Connection = TestObjects.GetTestConnectionDetails() }); - // verify that we are connected + // Then I expect to be connected to master Assert.NotEmpty(connectionResult.ConnectionId); - // register disconnect callback - connectionService.RegisterOnDisconnectTask( - (result, uri) => { - callbackInvoked = true; - Assert.True(uri.Equals(ownerUri)); - return Task.FromResult(true); - } - ); - - // send annother connect request + // And when I then connect to another DB + var updatedConnectionDetails = TestObjects.GetTestConnectionDetails(); + updatedConnectionDetails.DatabaseName = otherDbName; connectionResult = await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, - Connection = TestObjects.GetTestConnectionDetails() + Connection = updatedConnectionDetails }); + // Then I expect to be disconnected from master, and connected to the new DB // verify that the event was fired (we disconnected first before connecting) Assert.True(callbackInvoked); // verify that we connected again Assert.NotEmpty(connectionResult.ConnectionId); + Assert.Equal(otherDbName, connectionResult.ConnectionSummary.DatabaseName); } /// diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs index ef405469..a8d5b429 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs @@ -161,6 +161,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility /// public class TestSqlConnection : DbConnection { + private string _database; + internal TestSqlConnection(TestResultSet[] data) { Data = data; @@ -188,7 +190,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility } public override string ConnectionString { get; set; } - public override string Database { get; } + public override string Database + { + get { return _database; } + } + public override ConnectionState State { get; } public override string DataSource { get; } public override string ServerVersion { get; } @@ -202,6 +208,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility { // No Op } + + /// + /// Test helper method to set the database value + /// + /// + public void SetDatabase(string database) + { + this._database = database; + } } ///