Add force change database logic (#519)

* add logic to close connections if changing fails

* added logic to close connections and reopen that fail to change (azure)

* expose connection map in connection info, change while to foreach

* removed unneeded code

* reworked logic to not depend on thrown errors

* added tests
This commit is contained in:
Anthony Dresser
2017-10-25 10:54:17 -07:00
committed by GitHub
parent f80fd8a458
commit 399b03cbd1
3 changed files with 194 additions and 13 deletions

View File

@@ -1201,7 +1201,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
ChangeDatabaseParams changeDatabaseParams,
RequestContext<bool> requestContext)
{
await requestContext.SendResult(ChangeConnectionDatabaseContext(changeDatabaseParams.OwnerUri, changeDatabaseParams.NewDatabase));
await requestContext.SendResult(ChangeConnectionDatabaseContext(changeDatabaseParams.OwnerUri, changeDatabaseParams.NewDatabase, true));
}
/// <summary>
@@ -1209,23 +1209,42 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
/// </summary>
/// <param name="ownerUri">URI of the owner of the connection</param>
/// <param name="newDatabaseName">Name of the database to change the connection to</param>
public bool ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName)
public bool ChangeConnectionDatabaseContext(string ownerUri, string newDatabaseName, bool force = false)
{
ConnectionInfo info;
if (TryFindConnection(ownerUri, out info))
{
try
{
foreach (DbConnection connection in info.AllConnections)
{
if (connection.State == ConnectionState.Open)
{
connection.ChangeDatabase(newDatabaseName);
}
}
info.ConnectionDetails.DatabaseName = newDatabaseName;
foreach (string key in info.AllConnectionTypes)
{
DbConnection conn;
info.TryGetConnection(key, out conn);
if (conn != null && conn.Database != newDatabaseName && conn.State == ConnectionState.Open)
{
if (info.IsCloud && force)
{
conn.Close();
conn.Dispose();
info.RemoveConnection(key);
string connectionString = BuildConnectionString(info.ConnectionDetails);
// create a sql connection instance
DbConnection connection = info.Factory.CreateSqlConnection(connectionString);
connection.Open();
info.AddConnection(key, connection);
}
else
{
conn.ChangeDatabase(newDatabaseName);
}
}
}
// Fire a connection changed event
ConnectionChangedParams parameters = new ConnectionChangedParams();
IConnectionSummary summary = info.ConnectionDetails;

View File

@@ -1307,5 +1307,152 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
var connectionResult = await TestObjects.GetTestConnectionService().Connect(connectionParameters);
Assert.Equal(databaseName, connectionResult.ConnectionSummary.DatabaseName);
}
[Fact]
public async Task CanChangeDatabase()
{
string ownerUri = "file://my/sample/file.sql";
const string masterDbName = "master";
const string otherDbName = "other";
// Given a connection that returns the database name
var connection = new TestSqlConnection(null);
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
// Database name is respected. Follow heuristic where empty DB name really means Master
var dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog;
connection.SetDatabase(dbName);
return connection;
});
var connectionService = new ConnectionService(mockFactory.Object);
// When I connect to default
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
OwnerUri = ownerUri,
Connection = TestObjects.GetTestConnectionDetails()
});
connection.SetState(ConnectionState.Open);
connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName);
Assert.Equal(otherDbName, connection.Database);
}
[Fact]
public async Task CanChangeDatabaseAzure()
{
string ownerUri = "file://my/sample/file.sql";
const string masterDbName = "master";
const string otherDbName = "other";
string dbName = masterDbName;
// Given a connection that returns the database name
var mockConnection = new Mock<DbConnection>();
mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny<string>()))
.Throws(new Exception());
mockConnection.SetupGet(conn => conn.Database).Returns(dbName);
mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open);
mockConnection.Setup(conn => conn.Close());
mockConnection.Setup(conn => conn.Open());
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
// Database name is respected. Follow heuristic where empty DB name really means Master
dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? masterDbName : scsb.InitialCatalog;
return connection;
});
var connectionService = new ConnectionService(mockFactory.Object);
// When I connect to default
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
OwnerUri = ownerUri,
Connection = TestObjects.GetTestConnectionDetails()
});
ConnectionInfo testInfo;
connectionService.TryFindConnection(ownerUri, out testInfo);
Assert.NotNull(testInfo);
testInfo.IsCloud = true;
connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName, true);
Assert.Equal(otherDbName, dbName);
}
[Fact]
public async Task ReturnsFalseIfNotForced()
{
string ownerUri = "file://my/sample/file.sql";
const string defaultDbName = "databaseName";
const string otherDbName = "other";
string dbName = defaultDbName;
// Given a connection that returns the database name
var mockConnection = new Mock<DbConnection>();
mockConnection.Setup(conn => conn.ChangeDatabase(It.IsAny<string>()))
.Throws(new Exception());
mockConnection.SetupGet(conn => conn.Database).Returns(dbName);
mockConnection.SetupGet(conn => conn.State).Returns(ConnectionState.Open);
mockConnection.Setup(conn => conn.Close());
mockConnection.Setup(conn => conn.Open());
var connection = mockConnection.Object;
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns((string connString) =>
{
connection.ConnectionString = connString;
SqlConnectionStringBuilder scsb = new SqlConnectionStringBuilder(connString);
// Database name is respected. Follow heuristic where empty DB name really means Master
dbName = string.IsNullOrEmpty(scsb.InitialCatalog) ? defaultDbName : scsb.InitialCatalog;
return connection;
});
var connectionService = new ConnectionService(mockFactory.Object);
// When I connect to default
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
OwnerUri = ownerUri,
Connection = TestObjects.GetTestConnectionDetails()
});
ConnectionInfo testInfo;
connectionService.TryFindConnection(ownerUri, out testInfo);
Assert.NotNull(testInfo);
testInfo.IsCloud = true;
Assert.False(connectionService.ChangeConnectionDatabaseContext(ownerUri, otherDbName));
Assert.Equal(defaultDbName, dbName);
}
}
}

View File

@@ -194,8 +194,14 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
public class TestSqlConnection : DbConnection
{
private string _database;
private ConnectionState _state;
internal TestSqlConnection(TestResultSet[] data)
public TestSqlConnection()
{
}
public TestSqlConnection(TestResultSet[] data)
{
Data = data;
}
@@ -227,7 +233,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
get { return _database; }
}
public override ConnectionState State { get; }
public override ConnectionState State
{
get { return _state; }
}
public override string DataSource { get; }
public override string ServerVersion { get; }
@@ -238,7 +248,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
public override void ChangeDatabase(string databaseName)
{
// No Op
_database = databaseName;
}
/// <summary>
@@ -249,6 +259,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
{
this._database = database;
}
public void SetState(ConnectionState state)
{
this._state = state;
}
}
/// <summary>