Feature/connect cancel (#74)

* Implemented connection cancellation

* Made connect requests return immediately and created a separate connection complete notification

* Fix spelling

* Fix sorting

* Add separate lock for cancellation source map
This commit is contained in:
Mitchell Sternke
2016-10-04 15:45:52 -07:00
committed by GitHub
parent 62525b9c98
commit 8408bc6dff
17 changed files with 532 additions and 98 deletions

View File

@@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
@@ -52,6 +53,214 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
return connectionMock.Object;
}
[Fact]
public void CanCancelConnectRequest()
{
var testFile = "file:///my/test/file.sql";
// Given a connection that times out and responds to cancellation
var mockConnection = new Mock<DbConnection> { CallBase = true };
CancellationToken token;
bool ready = false;
mockConnection.Setup(x => x.OpenAsync(Moq.It.IsAny<CancellationToken>()))
.Callback<CancellationToken>(t =>
{
// Pass the token to the return handler and signal the main thread to cancel
token = t;
ready = true;
})
.Returns(() =>
{
if (TestUtils.WaitFor(() => token.IsCancellationRequested))
{
throw new OperationCanceledException();
}
else
{
return Task.FromResult(true);
}
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns(mockConnection.Object);
var connectionService = new ConnectionService(mockFactory.Object);
// Connect the connection asynchronously in a background thread
var connectionDetails = TestObjects.GetTestConnectionDetails();
var connectTask = Task.Run(async () =>
{
return await connectionService
.Connect(new ConnectParams()
{
OwnerUri = testFile,
Connection = connectionDetails
});
});
// Wait for the connection to call OpenAsync()
Assert.True(TestUtils.WaitFor(() => ready));
// Send a cancellation request
var cancelResult = connectionService
.CancelConnect(new CancelConnectParams()
{
OwnerUri = testFile
});
// Wait for the connection task to finish
connectTask.Wait();
// Verify that the connection was cancelled (no connection was created)
Assert.Null(connectTask.Result.ConnectionId);
// Verify that the cancel succeeded
Assert.True(cancelResult);
}
[Fact]
public async void CanCancelConnectRequestByConnecting()
{
var testFile = "file:///my/test/file.sql";
// Given a connection that times out and responds to cancellation
var mockConnection = new Mock<DbConnection> { CallBase = true };
CancellationToken token;
bool ready = false;
mockConnection.Setup(x => x.OpenAsync(Moq.It.IsAny<CancellationToken>()))
.Callback<CancellationToken>(t =>
{
// Pass the token to the return handler and signal the main thread to cancel
token = t;
ready = true;
})
.Returns(() =>
{
if (TestUtils.WaitFor(() => token.IsCancellationRequested))
{
throw new OperationCanceledException();
}
else
{
return Task.FromResult(true);
}
});
// Given a second connection that succeeds
var mockConnection2 = new Mock<DbConnection> { CallBase = true };
mockConnection2.Setup(x => x.OpenAsync(Moq.It.IsAny<CancellationToken>()))
.Returns(() => Task.Run(() => {}));
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns(mockConnection.Object)
.Returns(mockConnection2.Object);
var connectionService = new ConnectionService(mockFactory.Object);
// Connect the first connection asynchronously in a background thread
var connectionDetails = TestObjects.GetTestConnectionDetails();
var connectTask = Task.Run(async () =>
{
return await connectionService
.Connect(new ConnectParams()
{
OwnerUri = testFile,
Connection = connectionDetails
});
});
// Wait for the connection to call OpenAsync()
Assert.True(TestUtils.WaitFor(() => ready));
// Send a cancellation by trying to connect again
var connectResult = await connectionService
.Connect(new ConnectParams()
{
OwnerUri = testFile,
Connection = connectionDetails
});
// Wait for the first connection task to finish
connectTask.Wait();
// Verify that the first connection was cancelled (no connection was created)
Assert.Null(connectTask.Result.ConnectionId);
// Verify that the second connection succeeded
Assert.NotEmpty(connectResult.ConnectionId);
}
[Fact]
public void CanCancelConnectRequestByDisconnecting()
{
var testFile = "file:///my/test/file.sql";
// Given a connection that times out and responds to cancellation
var mockConnection = new Mock<DbConnection> { CallBase = true };
CancellationToken token;
bool ready = false;
mockConnection.Setup(x => x.OpenAsync(Moq.It.IsAny<CancellationToken>()))
.Callback<CancellationToken>(t =>
{
// Pass the token to the return handler and signal the main thread to cancel
token = t;
ready = true;
})
.Returns(() =>
{
if (TestUtils.WaitFor(() => token.IsCancellationRequested))
{
throw new OperationCanceledException();
}
else
{
return Task.FromResult(true);
}
});
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
.Returns(mockConnection.Object);
var connectionService = new ConnectionService(mockFactory.Object);
// Connect the first connection asynchronously in a background thread
var connectionDetails = TestObjects.GetTestConnectionDetails();
var connectTask = Task.Run(async () =>
{
return await connectionService
.Connect(new ConnectParams()
{
OwnerUri = testFile,
Connection = connectionDetails
});
});
// Wait for the connection to call OpenAsync()
Assert.True(TestUtils.WaitFor(() => ready));
// Send a cancellation by trying to disconnect
var disconnectResult = connectionService
.Disconnect(new DisconnectParams()
{
OwnerUri = testFile
});
// Wait for the first connection task to finish
connectTask.Wait();
// Verify that the first connection was cancelled (no connection was created)
Assert.Null(connectTask.Result.ConnectionId);
// Verify that the disconnect failed (since it caused a cancellation)
Assert.False(disconnectResult);
}
/// <summary>
/// Verify that we can connect to the default database when no database name is
/// provided as a parameter.
@@ -59,12 +268,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
[Theory]
[InlineDataAttribute(null)]
[InlineDataAttribute("")]
public void CanConnectWithEmptyDatabaseName(string databaseName)
public async void CanConnectWithEmptyDatabaseName(string databaseName)
{
// Connect
var connectionDetails = TestObjects.GetTestConnectionDetails();
connectionDetails.DatabaseName = databaseName;
var connectionResult =
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(new ConnectParams()
{
@@ -83,7 +292,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
[Theory]
[InlineDataAttribute("master")]
[InlineDataAttribute("nonMasterDb")]
public void ConnectToDefaultDatabaseRespondsWithActualDbName(string expectedDbName)
public async void ConnectToDefaultDatabaseRespondsWithActualDbName(string expectedDbName)
{
// Given connecting with empty database name will return the expected DB name
var connectionMock = new Mock<DbConnection> { CallBase = true };
@@ -99,7 +308,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
var connectionDetails = TestObjects.GetTestConnectionDetails();
connectionDetails.DatabaseName = string.Empty;
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -118,14 +327,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// connection, we disconnect first before connecting.
/// </summary>
[Fact]
public void ConnectingWhenConnectionExistCausesDisconnectThenConnect()
public async void ConnectingWhenConnectionExistCausesDisconnectThenConnect()
{
bool callbackInvoked = false;
// first connect
string ownerUri = "file://my/sample/file.sql";
var connectionService = TestObjects.GetTestConnectionService();
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -146,7 +355,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
);
// send annother connect request
connectionResult =
connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -165,7 +374,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verify that when connecting with invalid credentials, an error is thrown.
/// </summary>
[Fact]
public void ConnectingWithInvalidCredentialsYieldsErrorMessage()
public async void ConnectingWithInvalidCredentialsYieldsErrorMessage()
{
var testConnectionDetails = TestObjects.GetTestConnectionDetails();
var invalidConnectionDetails = new ConnectionDetails();
@@ -175,7 +384,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
invalidConnectionDetails.Password = "invalidPassword";
// Connect to test db with invalid credentials
var connectionResult =
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(new ConnectParams()
{
@@ -204,10 +413,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
[InlineData("Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456")]
[InlineData("Integrated", "", "my-server", "test", "sa", "123456")]
[InlineData("Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456")]
public void ConnectingWithInvalidParametersYieldsErrorMessage(string authType, string ownerUri, string server, string database, string userName, string password)
public async void ConnectingWithInvalidParametersYieldsErrorMessage(string authType, string ownerUri, string server, string database, string userName, string password)
{
// Connect with invalid parameters
var connectionResult =
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(new ConnectParams()
{
@@ -238,10 +447,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
[InlineData("sa", "")]
[InlineData(null, "12345678")]
[InlineData("", "12345678")]
public void ConnectingWithNoUsernameOrPasswordWorksForIntegratedAuth(string userName, string password)
public async void ConnectingWithNoUsernameOrPasswordWorksForIntegratedAuth(string userName, string password)
{
// Connect
var connectionResult =
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(new ConnectParams()
{
@@ -263,10 +472,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verify that when connecting with a null parameters object, an error is thrown.
/// </summary>
[Fact]
public void ConnectingWithNullParametersObjectYieldsErrorMessage()
public async void ConnectingWithNullParametersObjectYieldsErrorMessage()
{
// Connect with null parameters
var connectionResult =
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(null);
@@ -330,7 +539,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verify that a connection changed event is fired when the database context changes.
/// </summary>
[Fact]
public void ConnectionChangedEventIsFiredWhenDatabaseContextChanges()
public async void ConnectionChangedEventIsFiredWhenDatabaseContextChanges()
{
var serviceHostMock = new Mock<IProtocolEndpoint>();
@@ -339,7 +548,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
// Set up an initial connection
string ownerUri = "file://my/sample/file.sql";
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -364,11 +573,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verify that the SQL parser correctly detects errors in text
/// </summary>
[Fact]
public void ConnectToDatabaseTest()
public async void ConnectToDatabaseTest()
{
// connect to a database instance
string ownerUri = "file://my/sample/file.sql";
var connectionResult =
var connectionResult = await
TestObjects.GetTestConnectionService()
.Connect(new ConnectParams()
{
@@ -384,12 +593,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verify that we can disconnect from an active connection succesfully
/// </summary>
[Fact]
public void DisconnectFromDatabaseTest()
public async void DisconnectFromDatabaseTest()
{
// first connect
string ownerUri = "file://my/sample/file.sql";
var connectionService = TestObjects.GetTestConnectionService();
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -414,14 +623,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Test that when a disconnect is performed, the callback event is fired
/// </summary>
[Fact]
public void DisconnectFiresCallbackEvent()
public async void DisconnectFiresCallbackEvent()
{
bool callbackInvoked = false;
// first connect
string ownerUri = "file://my/sample/file.sql";
var connectionService = TestObjects.GetTestConnectionService();
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -458,12 +667,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Test that disconnecting an active connection removes the Owner URI -> ConnectionInfo mapping
/// </summary>
[Fact]
public void DisconnectRemovesOwnerMapping()
public async void DisconnectRemovesOwnerMapping()
{
// first connect
string ownerUri = "file://my/sample/file.sql";
var connectionService = TestObjects.GetTestConnectionService();
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -498,12 +707,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
[InlineDataAttribute(null)]
[InlineDataAttribute("")]
public void DisconnectValidatesParameters(string disconnectUri)
public async void DisconnectValidatesParameters(string disconnectUri)
{
// first connect
string ownerUri = "file://my/sample/file.sql";
var connectionService = TestObjects.GetTestConnectionService();
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -530,7 +739,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verifies the the list databases operation lists database names for the server used by a connection.
/// </summary>
[Fact]
public void ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames()
public async void ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames()
{
// Result set for the query of database names
Dictionary<string, string>[] data =
@@ -550,7 +759,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
// connect to a database instance
string ownerUri = "file://my/sample/file.sql";
var connectionResult =
var connectionResult = await
connectionService
.Connect(new ConnectParams()
{
@@ -579,7 +788,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verify that the SQL parser correctly detects errors in text
/// </summary>
[Fact]
public void OnConnectionCallbackHandlerTest()
public async void OnConnectionCallbackHandlerTest()
{
bool callbackInvoked = false;
@@ -593,7 +802,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
);
// connect to a database instance
var connectionResult = connectionService.Connect(TestObjects.GetTestConnectionParams());
var connectionResult = await connectionService.Connect(TestObjects.GetTestConnectionParams());
// verify that a valid connection id was returned
Assert.True(callbackInvoked);
@@ -603,14 +812,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection
/// Verify when a connection is created that the URI -> Connection mapping is created in the connection service.
/// </summary>
[Fact]
public void TestConnectRequestRegistersOwner()
public async void TestConnectRequestRegistersOwner()
{
// Given a request to connect to a database
var service = TestObjects.GetTestConnectionService();
var connectParams = TestObjects.GetTestConnectionParams();
// connect to a database instance
var connectionResult = service.Connect(connectParams);
var connectionResult = await service.Connect(connectParams);
// verify that a valid connection id was returned
Assert.NotNull(connectionResult.ConnectionId);