diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
index 2d82f3ce..457fa985 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionInfo.cs
@@ -99,6 +99,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
}
}
+ public bool HasConnectionType(string connectionType)
+ {
+ connectionType = connectionType ?? ConnectionType.Default;
+ return ConnectionTypeToConnectionMap.ContainsKey(connectionType);
+ }
+
///
/// The count of DbConnectioninstances held by this ConnectionInfo
///
@@ -155,7 +161,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
DbConnection connection;
ConnectionTypeToConnectionMap.TryRemove(type, out connection);
}
- }
-
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
index ae846247..9901952b 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs
@@ -188,44 +188,78 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
// If there is no ConnectionInfo in the map, create a new ConnectionInfo,
// but wait until later when we are connected to add it to the map.
ConnectionInfo connectionInfo;
+ bool connectionChanged = false;
if (!ownerToConnectionMap.TryGetValue(connectionParams.OwnerUri, out connectionInfo))
{
connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
}
+ else if (IsConnectionChanged(connectionParams, connectionInfo))
+ {
+ // We are actively changing the connection information for this connection. We must disconnect
+ // all active connections, since it represents a full context change
+ connectionChanged = true;
+ }
- // Resolve if it is an existing connection
- // Disconnect active connection if the URI is already connected for this connection type
+ DisconnectExistingConnectionIfNeeded(connectionParams, connectionInfo, disconnectAll: connectionChanged);
+
+ if (connectionChanged)
+ {
+ connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection);
+ }
+
+ // Try to open a connection with the given ConnectParams
+ ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams);
+ if (response != null)
+ {
+ return response;
+ }
+
+ // If this is the first connection for this URI, add the ConnectionInfo to the map
+ bool addToMap = connectionChanged || !ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri);
+ if (addToMap)
+ {
+ ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
+ }
+
+ // Return information about the connected SQL Server instance
+ ConnectionCompleteParams completeParams = GetConnectionCompleteParams(connectionParams.Type, connectionInfo);
+ // Invoke callback notifications
+ InvokeOnConnectionActivities(connectionInfo, connectionParams);
+
+ return completeParams;
+ }
+
+ private bool IsConnectionChanged(ConnectParams connectionParams, ConnectionInfo connectionInfo)
+ {
+ if (connectionInfo.HasConnectionType(connectionParams.Type)
+ && !connectionInfo.ConnectionDetails.IsComparableTo(connectionParams.Connection))
+ {
+ return true;
+ }
+ return false;
+ }
+
+ private bool IsDefaultConnectionType(string connectionType)
+ {
+ return string.IsNullOrEmpty(connectionType) || ConnectionType.Default.Equals(connectionType, StringComparison.CurrentCultureIgnoreCase);
+ }
+
+ private void DisconnectExistingConnectionIfNeeded(ConnectParams connectionParams, ConnectionInfo connectionInfo, bool disconnectAll)
+ {
+ // Resolve if it is an existing connection
+ // Disconnect active connection if the URI is already connected for this connection type
DbConnection existingConnection;
if (connectionInfo.TryGetConnection(connectionParams.Type, out existingConnection))
{
var disconnectParams = new DisconnectParams()
{
- OwnerUri = connectionParams.OwnerUri,
- Type = connectionParams.Type
+ OwnerUri = connectionParams.OwnerUri,
+ Type = disconnectAll ? null : connectionParams.Type
};
Disconnect(disconnectParams);
- }
-
- // Try to open a connection with the given ConnectParams
- ConnectionCompleteParams response = await TryOpenConnection(connectionInfo, connectionParams);
- if (response != null)
- {
- return response;
- }
-
- // If this is the first connection for this URI, add the ConnectionInfo to the map
- if (!ownerToConnectionMap.ContainsKey(connectionParams.OwnerUri))
- {
- ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo;
- }
-
- // Invoke callback notifications
- InvokeOnConnectionActivities(connectionInfo, connectionParams);
-
- // Return information about the connected SQL Server instance
- return GetConnectionCompleteParams(connectionParams.Type, connectionInfo);
- }
-
+ }
+ }
+
///
/// Creates a ConnectionCompleteParams as a response to a successful connection.
/// Also sets the DatabaseName and IsAzure properties of ConnectionInfo.
@@ -304,46 +338,21 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
connectionInfo.AddConnection(connectionParams.Type, connection);
// Add a cancellation token source so that the connection OpenAsync() can be cancelled
- using (source = new CancellationTokenSource())
+ source = new CancellationTokenSource();
+ // Locking here to perform two operations as one atomic operation
+ lock (cancellationTokenSourceLock)
{
- // Locking here to perform two operations as one atomic operation
- lock (cancellationTokenSourceLock)
+ // If the URI is currently connecting from a different request, cancel it before we try to connect
+ CancellationTokenSource currentSource;
+ if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource))
{
- // If the URI is currently connecting from a different request, cancel it before we try to connect
- CancellationTokenSource currentSource;
- if (cancelTupleToCancellationTokenSourceMap.TryGetValue(cancelKey, out currentSource))
- {
- currentSource.Cancel();
- }
- cancelTupleToCancellationTokenSourceMap[cancelKey] = source;
+ currentSource.Cancel();
}
-
- // Create a task to handle cancellation requests
- var cancellationTask = Task.Run(() =>
- {
- source.Token.WaitHandle.WaitOne();
- try
- {
- source.Token.ThrowIfCancellationRequested();
- }
- catch (ObjectDisposedException)
- {
- // If ObjectDisposedException was thrown, then execution has already exited the
- // "using" statment and source was disposed, meaning that the openTask completed
- // successfully. This results in a ObjectDisposedException when trying to access
- // source.Token and should be ignored.
- }
- });
-
- var openTask = Task.Run(async () =>
- {
- await connection.OpenAsync(source.Token);
- });
-
- // Open the connection
- await Task.WhenAny(openTask, cancellationTask).Unwrap();
- source.Cancel();
+ cancelTupleToCancellationTokenSourceMap[cancelKey] = source;
}
+
+ // Open the connection
+ await connection.OpenAsync(source.Token);
}
catch (SqlException ex)
{
@@ -376,6 +385,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
cancelTupleToCancellationTokenSourceMap.TryRemove(cancelKey, out sourceValue);
}
+ source?.Dispose();
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
index bdce4fcb..68dfa51f 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionDetails.cs
@@ -5,7 +5,7 @@
using System;
using System.Collections.Generic;
-using System.Globalization;
+using System.Globalization;
using Microsoft.SqlTools.Utility;
namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
@@ -485,6 +485,33 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
Options.Add(name, value);
}
- }
+ }
+
+ public bool IsComparableTo(ConnectionDetails other)
+ {
+ if (other == null)
+ {
+ return false;
+ }
+
+ if (!string.Equals(ServerName, other.ServerName)
+ || !string.Equals(AuthenticationType, other.AuthenticationType)
+ || !string.Equals(UserName, other.UserName))
+ {
+ return false;
+ }
+
+ // For database name, only compare if neither is empty. This is important
+ // Since it allows for handling of connections to the default database, but is
+ // not a 100% accurate heuristic.
+ if (!string.IsNullOrEmpty(DatabaseName)
+ && !string.IsNullOrEmpty(other.DatabaseName)
+ && !string.Equals(DatabaseName, other.DatabaseName))
+ {
+ return false;
+ }
+
+ return true;
+ }
}
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
index 8efe3984..690f7577 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs
@@ -314,9 +314,37 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{
bool callbackInvoked = false;
- // first connect
string ownerUri = "file://my/sample/file.sql";
- var connectionService = TestObjects.GetTestConnectionService();
+ const string masterDbName = "master";
+ const string otherDbName = "other";
+ // Given a connection that returns the database name
+ var dummySqlConnection = new TestSqlConnection(null);
+
+ var mockFactory = new Mock();
+ mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny()))
+ .Returns((string connString) =>
+ {
+ dummySqlConnection.ConnectionString = connString;
+ SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
+
+ // Database name is respected. Follow heuristic where empty DB name really means Master
+ var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog;
+ dummySqlConnection.SetDatabase(dbName);
+ return dummySqlConnection;
+ });
+
+ var connectionService = new ConnectionService(mockFactory.Object);
+
+ // register disconnect callback
+ connectionService.RegisterOnDisconnectTask(
+ (result, uri) => {
+ callbackInvoked = true;
+ Assert.True(uri.Equals(ownerUri));
+ return Task.FromResult(true);
+ }
+ );
+
+ // When I connect to default
var connectionResult = await
connectionService
.Connect(new ConnectParams()
@@ -325,32 +353,27 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Connection = TestObjects.GetTestConnectionDetails()
});
- // verify that we are connected
+ // Then I expect to be connected to master
Assert.NotEmpty(connectionResult.ConnectionId);
- // register disconnect callback
- connectionService.RegisterOnDisconnectTask(
- (result, uri) => {
- callbackInvoked = true;
- Assert.True(uri.Equals(ownerUri));
- return Task.FromResult(true);
- }
- );
-
- // send annother connect request
+ // And when I then connect to another DB
+ var updatedConnectionDetails = TestObjects.GetTestConnectionDetails();
+ updatedConnectionDetails.DatabaseName = otherDbName;
connectionResult = await
connectionService
.Connect(new ConnectParams()
{
OwnerUri = ownerUri,
- Connection = TestObjects.GetTestConnectionDetails()
+ Connection = updatedConnectionDetails
});
+ // Then I expect to be disconnected from master, and connected to the new DB
// verify that the event was fired (we disconnected first before connecting)
Assert.True(callbackInvoked);
// verify that we connected again
Assert.NotEmpty(connectionResult.ConnectionId);
+ Assert.Equal(otherDbName, connectionResult.ConnectionSummary.DatabaseName);
}
///
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs
index ef405469..a8d5b429 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs
@@ -161,6 +161,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
///
public class TestSqlConnection : DbConnection
{
+ private string _database;
+
internal TestSqlConnection(TestResultSet[] data)
{
Data = data;
@@ -188,7 +190,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
}
public override string ConnectionString { get; set; }
- public override string Database { get; }
+ public override string Database
+ {
+ get { return _database; }
+ }
+
public override ConnectionState State { get; }
public override string DataSource { get; }
public override string ServerVersion { get; }
@@ -202,6 +208,15 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
{
// No Op
}
+
+ ///
+ /// Test helper method to set the database value
+ ///
+ ///
+ public void SetDatabase(string database)
+ {
+ this._database = database;
+ }
}
///