diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 1904033d..8d26f9a4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -1201,7 +1201,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection ChangeDatabaseParams changeDatabaseParams, RequestContext requestContext) { - await requestContext.SendResult(ChangeConnectionDatabaseContext(changeDatabaseParams.OwnerUri, changeDatabaseParams.NewDatabase)); + await requestContext.SendResult(ChangeConnectionDatabaseContext(changeDatabaseParams.OwnerUri, changeDatabaseParams.NewDatabase, true)); } /// @@ -1209,23 +1209,42 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// /// URI of the owner of the connection /// Name of the database to change the connection to - public bool ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName) + public bool ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName, bool force = false) { ConnectionInfo info; if (TryFindConnection(ownerUri, out info)) { try { - foreach (DbConnection connection in info.AllConnections) - { - if (connection.State == ConnectionState.Open) - { - connection.ChangeDatabase(newDatabaseName); - } - } - info.ConnectionDetails.DatabaseName = newDatabaseName; + foreach (string key in info.AllConnectionTypes) + { + DbConnection conn; + info.TryGetConnection(key, out conn); + if (conn != null && conn.Database != newDatabaseName && conn.State == ConnectionState.Open) + { + if (info.IsCloud && force) + { + conn.Close(); + conn.Dispose(); + info.RemoveConnection(key); + + string connectionString = BuildConnectionString(info.ConnectionDetails); + + // create a sql connection instance + DbConnection connection = info.Factory.CreateSqlConnection(connectionString); + connection.Open(); + info.AddConnection(key, connection); + } + else + { + conn.ChangeDatabase(newDatabaseName); + } + } + + } + // Fire a connection changed event ConnectionChangedParams parameters = new ConnectionChangedParams(); IConnectionSummary summary = info.ConnectionDetails; diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index 282b110a..982f96c2 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -1307,5 +1307,152 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters); Assert.Equal(databaseName, connectionResult.ConnectionSummary.DatabaseName); } + + [Fact] + public async Task CanChangeDatabase() + { + string ownerUri = "file://my/sample/file.sql"; + const string masterDbName = "master"; + const string otherDbName = "other"; + // Given a connection that returns the database name + var connection = new TestSqlConnection(null); + + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns((string connString) => + { + connection.ConnectionString = connString; + SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); + + // Database name is respected. Follow heuristic where empty DB name really means Master + var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; + connection.SetDatabase(dbName); + return connection; + }); + + var connectionService = new ConnectionService(mockFactory.Object); + + // When I connect to default + var connectionResult = await + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + connection.SetState(ConnectionState.Open); + + connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName); + + Assert.Equal(otherDbName, connection.Database); + } + + [Fact] + public async Task CanChangeDatabaseAzure() + { + + string ownerUri = "file://my/sample/file.sql"; + const string masterDbName = "master"; + const string otherDbName = "other"; + string dbName = masterDbName; + // Given a connection that returns the database name + var mockConnection = new Mock(); + mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny())) + .Throws(new Exception()); + mockConnection.SetupGet(conn => conn.Database).Returns(dbName); + mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open); + mockConnection.Setup(conn => conn.Close()); + mockConnection.Setup(conn => conn.Open()); + + var connection = mockConnection.Object; + + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns((string connString) => + { + connection.ConnectionString = connString; + SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); + + // Database name is respected. Follow heuristic where empty DB name really means Master + dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog; + return connection; + }); + + var connectionService = new ConnectionService(mockFactory.Object); + + // When I connect to default + var connectionResult = await + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + ConnectionInfo testInfo; + connectionService.TryFindConnection(ownerUri, out testInfo); + + Assert.NotNull(testInfo); + + testInfo.IsCloud = true; + + connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName, true); + + Assert.Equal(otherDbName, dbName); + } + + [Fact] + public async Task ReturnsFalseIfNotForced() + { + string ownerUri = "file://my/sample/file.sql"; + const string defaultDbName = "databaseName"; + const string otherDbName = "other"; + string dbName = defaultDbName; + // Given a connection that returns the database name + var mockConnection = new Mock(); + mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny())) + .Throws(new Exception()); + mockConnection.SetupGet(conn => conn.Database).Returns(dbName); + mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open); + mockConnection.Setup(conn => conn.Close()); + mockConnection.Setup(conn => conn.Open()); + + var connection = mockConnection.Object; + + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns((string connString) => + { + connection.ConnectionString = connString; + SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString); + + // Database name is respected. Follow heuristic where empty DB name really means Master + dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? defaultDbName : scsb.InitialCatalog; + return connection; + }); + + var connectionService = new ConnectionService(mockFactory.Object); + + // When I connect to default + var connectionResult = await + connectionService + .Connect(new ConnectParams() + { + OwnerUri = ownerUri, + Connection = TestObjects.GetTestConnectionDetails() + }); + + ConnectionInfo testInfo; + connectionService.TryFindConnection(ownerUri, out testInfo); + + Assert.NotNull(testInfo); + + testInfo.IsCloud = true; + + Assert.False(connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName)); + + Assert.Equal(defaultDbName, dbName); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs index 3afb7b50..112e0a8c 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestObjects.cs @@ -194,8 +194,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility public class TestSqlConnection : DbConnection { private string _database; + private ConnectionState _state; - internal TestSqlConnection(TestResultSet[] data) + public TestSqlConnection() + { + + } + + public TestSqlConnection(TestResultSet[] data) { Data = data; } @@ -227,7 +233,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility get { return _database; } } - public override ConnectionState State { get; } + public override ConnectionState State + { + get { return _state; } + } + public override string DataSource { get; } public override string ServerVersion { get; } @@ -238,7 +248,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility public override void ChangeDatabase(string databaseName) { - // No Op + _database = databaseName; } /// @@ -249,6 +259,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility { this._database = database; } + + public void SetState(ConnectionState state) + { + this._state = state; + } } ///