mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-14 01:25:40 -05:00
Add force change database logic (#519)
* add logic to close connections if changing fails * added logic to close connections and reopen that fail to change (azure) * expose connection map in connection info, change while to foreach * removed unneeded code * reworked logic to not depend on thrown errors * added tests
This commit is contained in:
@@ -1201,7 +1201,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
ChangeDatabaseParams changeDatabaseParams,
|
||||
RequestContext<bool> requestContext)
|
||||
{
|
||||
await requestContext.SendResult(ChangeConnectionDatabaseContext(changeDatabaseParams.OwnerUri, changeDatabaseParams.NewDatabase));
|
||||
await requestContext.SendResult(ChangeConnectionDatabaseContext(changeDatabaseParams.OwnerUri, changeDatabaseParams.NewDatabase, true));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -1209,23 +1209,42 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
|
||||
/// </summary>
|
||||
/// <param name="ownerUri">URI of the owner of the connection</param>
|
||||
/// <param name="newDatabaseName">Name of the database to change the connection to</param>
|
||||
public bool ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName)
|
||||
public bool ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName, bool force = false)
|
||||
{
|
||||
ConnectionInfo info;
|
||||
if (TryFindConnection(ownerUri, out info))
|
||||
{
|
||||
try
|
||||
{
|
||||
foreach (DbConnection connection in info.AllConnections)
|
||||
{
|
||||
if (connection.State == ConnectionState.Open)
|
||||
{
|
||||
connection.ChangeDatabase(newDatabaseName);
|
||||
}
|
||||
}
|
||||
|
||||
info.ConnectionDetails.DatabaseName = newDatabaseName;
|
||||
|
||||
foreach (string key in info.AllConnectionTypes)
|
||||
{
|
||||
DbConnection conn;
|
||||
info.TryGetConnection(key, out conn);
|
||||
if (conn != null && conn.Database != newDatabaseName && conn.State == ConnectionState.Open)
|
||||
{
|
||||
if (info.IsCloud && force)
|
||||
{
|
||||
conn.Close();
|
||||
conn.Dispose();
|
||||
info.RemoveConnection(key);
|
||||
|
||||
string connectionString = BuildConnectionString(info.ConnectionDetails);
|
||||
|
||||
// create a sql connection instance
|
||||
DbConnection connection = info.Factory.CreateSqlConnection(connectionString);
|
||||
connection.Open();
|
||||
info.AddConnection(key, connection);
|
||||
}
|
||||
else
|
||||
{
|
||||
conn.ChangeDatabase(newDatabaseName);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fire a connection changed event
|
||||
ConnectionChangedParams parameters = new ConnectionChangedParams();
|
||||
IConnectionSummary summary = info.ConnectionDetails;
|
||||
|
||||
@@ -1307,5 +1307,152 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
|
||||
var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters);
|
||||
Assert.Equal(databaseName, connectionResult.ConnectionSummary.DatabaseName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CanChangeDatabase()
|
||||
{
|
||||
string ownerUri = "file://my/sample/file.sql";
|
||||
const string masterDbName = "master";
|
||||
const string otherDbName = "other";
|
||||
// Given a connection that returns the database name
|
||||
var connection = new TestSqlConnection(null);
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
.Returns((string connString) =>
|
||||
{
|
||||
connection.ConnectionString = connString;
|
||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||
|
||||
// Database name is respected. Follow heuristic where empty DB name really means Master
|
||||
var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog;
|
||||
connection.SetDatabase(dbName);
|
||||
return connection;
|
||||
});
|
||||
|
||||
var connectionService = new ConnectionService(mockFactory.Object);
|
||||
|
||||
// When I connect to default
|
||||
var connectionResult = await
|
||||
connectionService
|
||||
.Connect(new ConnectParams()
|
||||
{
|
||||
OwnerUri = ownerUri,
|
||||
Connection = TestObjects.GetTestConnectionDetails()
|
||||
});
|
||||
|
||||
connection.SetState(ConnectionState.Open);
|
||||
|
||||
connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName);
|
||||
|
||||
Assert.Equal(otherDbName, connection.Database);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CanChangeDatabaseAzure()
|
||||
{
|
||||
|
||||
string ownerUri = "file://my/sample/file.sql";
|
||||
const string masterDbName = "master";
|
||||
const string otherDbName = "other";
|
||||
string dbName = masterDbName;
|
||||
// Given a connection that returns the database name
|
||||
var mockConnection = new Mock<DbConnection>();
|
||||
mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny<string>()))
|
||||
.Throws(new Exception());
|
||||
mockConnection.SetupGet(conn => conn.Database).Returns(dbName);
|
||||
mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open);
|
||||
mockConnection.Setup(conn => conn.Close());
|
||||
mockConnection.Setup(conn => conn.Open());
|
||||
|
||||
var connection = mockConnection.Object;
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
.Returns((string connString) =>
|
||||
{
|
||||
connection.ConnectionString = connString;
|
||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||
|
||||
// Database name is respected. Follow heuristic where empty DB name really means Master
|
||||
dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog;
|
||||
return connection;
|
||||
});
|
||||
|
||||
var connectionService = new ConnectionService(mockFactory.Object);
|
||||
|
||||
// When I connect to default
|
||||
var connectionResult = await
|
||||
connectionService
|
||||
.Connect(new ConnectParams()
|
||||
{
|
||||
OwnerUri = ownerUri,
|
||||
Connection = TestObjects.GetTestConnectionDetails()
|
||||
});
|
||||
|
||||
ConnectionInfo testInfo;
|
||||
connectionService.TryFindConnection(ownerUri, out testInfo);
|
||||
|
||||
Assert.NotNull(testInfo);
|
||||
|
||||
testInfo.IsCloud = true;
|
||||
|
||||
connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName, true);
|
||||
|
||||
Assert.Equal(otherDbName, dbName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ReturnsFalseIfNotForced()
|
||||
{
|
||||
string ownerUri = "file://my/sample/file.sql";
|
||||
const string defaultDbName = "databaseName";
|
||||
const string otherDbName = "other";
|
||||
string dbName = defaultDbName;
|
||||
// Given a connection that returns the database name
|
||||
var mockConnection = new Mock<DbConnection>();
|
||||
mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny<string>()))
|
||||
.Throws(new Exception());
|
||||
mockConnection.SetupGet(conn => conn.Database).Returns(dbName);
|
||||
mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open);
|
||||
mockConnection.Setup(conn => conn.Close());
|
||||
mockConnection.Setup(conn => conn.Open());
|
||||
|
||||
var connection = mockConnection.Object;
|
||||
|
||||
var mockFactory = new Mock<ISqlConnectionFactory>();
|
||||
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
|
||||
.Returns((string connString) =>
|
||||
{
|
||||
connection.ConnectionString = connString;
|
||||
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
|
||||
|
||||
// Database name is respected. Follow heuristic where empty DB name really means Master
|
||||
dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? defaultDbName : scsb.InitialCatalog;
|
||||
return connection;
|
||||
});
|
||||
|
||||
var connectionService = new ConnectionService(mockFactory.Object);
|
||||
|
||||
// When I connect to default
|
||||
var connectionResult = await
|
||||
connectionService
|
||||
.Connect(new ConnectParams()
|
||||
{
|
||||
OwnerUri = ownerUri,
|
||||
Connection = TestObjects.GetTestConnectionDetails()
|
||||
});
|
||||
|
||||
ConnectionInfo testInfo;
|
||||
connectionService.TryFindConnection(ownerUri, out testInfo);
|
||||
|
||||
Assert.NotNull(testInfo);
|
||||
|
||||
testInfo.IsCloud = true;
|
||||
|
||||
Assert.False(connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName));
|
||||
|
||||
Assert.Equal(defaultDbName, dbName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,8 +194,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
|
||||
public class TestSqlConnection : DbConnection
|
||||
{
|
||||
private string _database;
|
||||
private ConnectionState _state;
|
||||
|
||||
internal TestSqlConnection(TestResultSet[] data)
|
||||
public TestSqlConnection()
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
public TestSqlConnection(TestResultSet[] data)
|
||||
{
|
||||
Data = data;
|
||||
}
|
||||
@@ -227,7 +233,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
|
||||
get { return _database; }
|
||||
}
|
||||
|
||||
public override ConnectionState State { get; }
|
||||
public override ConnectionState State
|
||||
{
|
||||
get { return _state; }
|
||||
}
|
||||
|
||||
public override string DataSource { get; }
|
||||
public override string ServerVersion { get; }
|
||||
|
||||
@@ -238,7 +248,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
|
||||
|
||||
public override void ChangeDatabase(string databaseName)
|
||||
{
|
||||
// No Op
|
||||
_database = databaseName;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -249,6 +259,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
|
||||
{
|
||||
this._database = database;
|
||||
}
|
||||
|
||||
public void SetState(ConnectionState state)
|
||||
{
|
||||
this._state = state;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
||||
Reference in New Issue
Block a user