diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index 80707a55..fd06343c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -4,10 +4,12 @@ // using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlClient; +using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; @@ -47,6 +49,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection private Dictionary ownerToConnectionMap = new Dictionary(); + private ConcurrentDictionary ownerToCancellationTokenSourceMap = new ConcurrentDictionary(); + + private Object cancellationTokenSourceLock = new Object(); + /// /// Map from script URIs to ConnectionInfo objects /// This is internal for testing access only @@ -131,21 +137,22 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// Open a connection with the specified connection details /// /// - public ConnectResponse Connect(ConnectParams connectionParams) + public async Task Connect(ConnectParams connectionParams) { // Validate parameters string paramValidationErrorMessage; if (connectionParams == null) { - return new ConnectResponse + return new ConnectionCompleteParams { Messages = SR.ConnectionServiceConnectErrorNullParams }; } if (!connectionParams.IsValid(out paramValidationErrorMessage)) { - return new ConnectResponse + return new ConnectionCompleteParams { + OwnerUri = connectionParams.OwnerUri, Messages = paramValidationErrorMessage }; } @@ -164,7 +171,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection connectionInfo = new ConnectionInfo(ConnectionFactory, connectionParams.OwnerUri, connectionParams.Connection); // try to connect - var response = new ConnectResponse(); + var response = new ConnectionCompleteParams(); + response.OwnerUri = connectionParams.OwnerUri; + CancellationTokenSource source = null; try { // build the connection string from the input parameters @@ -177,7 +186,36 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // we'll remove this once ConnectionService is refactored to not own the LanguageService connection connectionInfo.ConnectionDetails.MultipleActiveResultSets = true; - connectionInfo.SqlConnection.Open(); + // Add a cancellation token source so that the connection OpenAsync() can be cancelled + using (source = new CancellationTokenSource()) + { + // Locking here to perform two operations as one atomic operation + lock (cancellationTokenSourceLock) + { + // If the URI is currently connecting from a different request, cancel it before we try to connect + CancellationTokenSource currentSource; + if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out currentSource)) + { + currentSource.Cancel(); + } + ownerToCancellationTokenSourceMap[connectionParams.OwnerUri] = source; + } + + // Create a task to handle cancellation requests + var cancellationTask = Task.Run(() => + { + source.Token.WaitHandle.WaitOne(); + source.Token.ThrowIfCancellationRequested(); + }); + + var openTask = Task.Run(async () => { + await connectionInfo.SqlConnection.OpenAsync(source.Token); + }); + + // Open the connection + await Task.WhenAny(openTask, cancellationTask).Unwrap(); + source.Cancel(); + } } catch (SqlException ex) { @@ -186,12 +224,32 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection response.Messages = ex.ToString(); return response; } + catch (OperationCanceledException) + { + // OpenAsync was cancelled + response.Messages = SR.ConnectionServiceConnectionCanceled; + return response; + } catch (Exception ex) { response.ErrorMessage = ex.Message; response.Messages = ex.ToString(); return response; } + finally + { + // Remove our cancellation token from the map since we're no longer connecting + // Using a lock here to perform two operations as one atomic operation + lock (cancellationTokenSourceLock) + { + // Only remove the token from the map if it is the same one created by this request + CancellationTokenSource sourceValue; + if (ownerToCancellationTokenSourceMap.TryGetValue(connectionParams.OwnerUri, out sourceValue) && sourceValue == source) + { + ownerToCancellationTokenSourceMap.TryRemove(connectionParams.OwnerUri, out sourceValue); + } + } + } ownerToConnectionMap[connectionParams.OwnerUri] = connectionInfo; @@ -208,7 +266,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // invoke callback notifications foreach (var activity in this.onConnectionActivities) { - activity(connectionInfo); + await activity(connectionInfo); } // try to get information about the connected SQL Server instance @@ -242,6 +300,37 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return response; } + /// + /// Cancel a connection that is in the process of opening. + /// + public bool CancelConnect(CancelConnectParams cancelParams) + { + // Validate parameters + if (cancelParams == null || string.IsNullOrEmpty(cancelParams.OwnerUri)) + { + return false; + } + + // Cancel any current connection attempts for this URI + CancellationTokenSource source; + if (ownerToCancellationTokenSourceMap.TryGetValue(cancelParams.OwnerUri, out source)) + { + try + { + source.Cancel(); + return true; + } + catch + { + return false; + } + } + else + { + return false; + } + } + /// /// Close a connection with the specified connection details. /// @@ -253,6 +342,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection return false; } + // Cancel if we are in the middle of connecting + if (CancelConnect(new CancelConnectParams() { OwnerUri = disconnectParams.OwnerUri })) + { + return false; + } + // Lookup the connection owned by the URI ConnectionInfo info; if (!ownerToConnectionMap.TryGetValue(disconnectParams.OwnerUri, out info)) @@ -327,6 +422,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Register request and event handlers with the Service Host serviceHost.SetRequestHandler(ConnectionRequest.Type, HandleConnectRequest); + serviceHost.SetRequestHandler(CancelConnectRequest.Type, HandleCancelConnectRequest); serviceHost.SetRequestHandler(DisconnectRequest.Type, HandleDisconnectRequest); serviceHost.SetRequestHandler(ListDatabasesRequest.Type, HandleListDatabasesRequest); @@ -359,14 +455,50 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection /// protected async Task HandleConnectRequest( ConnectParams connectParams, - RequestContext requestContext) + RequestContext requestContext) { Logger.Write(LogLevel.Verbose, "HandleConnectRequest"); try { - // open connection base on request details - ConnectResponse result = ConnectionService.Instance.Connect(connectParams); + // create a task to connect asyncronously so that other requests are not blocked in the meantime + Task.Run(async () => + { + try + { + // open connection based on request details + ConnectionCompleteParams result = await ConnectionService.Instance.Connect(connectParams); + await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); + } + catch (Exception ex) + { + ConnectionCompleteParams result = new ConnectionCompleteParams() + { + Messages = ex.ToString() + }; + await ServiceHost.SendEvent(ConnectionCompleteNotification.Type, result); + } + }); + await requestContext.SendResult(true); + } + catch + { + await requestContext.SendResult(false); + } + } + + /// + /// Handle cancel connect requests + /// + protected async Task HandleCancelConnectRequest( + CancelConnectParams cancelParams, + RequestContext requestContext) + { + Logger.Write(LogLevel.Verbose, "HandleCancelConnectRequest"); + + try + { + bool result = ConnectionService.Instance.CancelConnect(cancelParams); await requestContext.SendResult(result); } catch(Exception ex) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs new file mode 100644 index 00000000..9f2efdb0 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectParams.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Parameters for the Cancel Connect Request. + /// + public class CancelConnectParams + { + /// + /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace + /// or a virtual file representing an object in a database. + /// + public string OwnerUri { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectRequest.cs new file mode 100644 index 00000000..a284f317 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/CancelConnectRequest.cs @@ -0,0 +1,19 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts +{ + /// + /// Cancel connect request mapping entry + /// + public class CancelConnectRequest + { + public static readonly + RequestType Type = + RequestType.Create("connection/cancelconnect"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs similarity index 62% rename from src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs rename to src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs index 9dcf061e..50517a52 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectResponse.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionCompleteNotification.cs @@ -3,13 +3,21 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { /// - /// Message format for the connection result response + /// Parameters to be sent back with a connection complete event /// - public class ConnectResponse + public class ConnectionCompleteParams { + /// + /// A URI identifying the owner of the connection. This will most commonly be a file in the workspace + /// or a virtual file representing an object in a database. + /// + public string OwnerUri { get; set; } + /// /// A GUID representing a unique connection ID /// @@ -40,4 +48,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// public ConnectionSummary ConnectionSummary { get; set; } } + + /// + /// ConnectionComplete notification mapping entry + /// + public class ConnectionCompleteNotification + { + public static readonly + EventType Type = + EventType.Create("connection/complete"); + } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs index 50251e12..74320bdd 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ConnectionRequest.cs @@ -13,7 +13,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts public class ConnectionRequest { public static readonly - RequestType Type = - RequestType.Create("connection/connect"); + RequestType Type = + RequestType.Create("connection/connect"); } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.cs b/src/Microsoft.SqlTools.ServiceLayer/sr.cs index 213c3d55..811ab975 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.cs @@ -45,6 +45,14 @@ namespace Microsoft.SqlTools.ServiceLayer } } + public static string ConnectionServiceConnectionCanceled + { + get + { + return Keys.GetString(Keys.ConnectionServiceConnectionCanceled); + } + } + public static string ConnectionParamsValidateNullOwnerUri { get @@ -368,7 +376,7 @@ namespace Microsoft.SqlTools.ServiceLayer [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] public class Keys { - static ResourceManager resourceManager = new ResourceManager("Microsoft.SqlTools.ServiceLayer.SR", typeof(SR).GetTypeInfo().Assembly); + static ResourceManager resourceManager = new ResourceManager(typeof(SR)); static CultureInfo _culture = null; @@ -388,6 +396,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string ConnectionServiceConnStringInvalidIntent = "ConnectionServiceConnStringInvalidIntent"; + public const string ConnectionServiceConnectionCanceled = "ConnectionServiceConnectionCanceled"; + + public const string ConnectionParamsValidateNullOwnerUri = "ConnectionParamsValidateNullOwnerUri"; diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.resx b/src/Microsoft.SqlTools.ServiceLayer/sr.resx index 3f8e9318..63d7e71b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.resx +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.resx @@ -140,6 +140,10 @@ . Parameters: 0 - intent (string) + + Connection canceled + + OwnerUri cannot be null or empty diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/sr.strings index a9945f20..a74a54d9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.strings @@ -33,6 +33,8 @@ ConnectionServiceConnStringInvalidAuthType(string authType) = Invalid value '{0} ConnectionServiceConnStringInvalidIntent(string intent) = Invalid value '{0}' for ApplicationIntent. Valid values are 'ReadWrite' and 'ReadOnly'. +ConnectionServiceConnectionCanceled = Connection canceled + ###### ### Connection Params Validation Errors diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs index 82b5aca3..8205cf21 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Connection/ConnectionServiceTests.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Reflection; +using System.Threading; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; @@ -52,6 +53,214 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection return connectionMock.Object; } + [Fact] + public void CanCancelConnectRequest() + { + var testFile = "file:///my/test/file.sql"; + + // Given a connection that times out and responds to cancellation + var mockConnection = new Mock { CallBase = true }; + CancellationToken token; + bool ready = false; + mockConnection.Setup(x => x.OpenAsync(Moq.It.IsAny())) + .Callback(t => + { + // Pass the token to the return handler and signal the main thread to cancel + token = t; + ready = true; + }) + .Returns(() => + { + if (TestUtils.WaitFor(() => token.IsCancellationRequested)) + { + throw new OperationCanceledException(); + } + else + { + return Task.FromResult(true); + } + }); + + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(mockConnection.Object); + + + var connectionService = new ConnectionService(mockFactory.Object); + + // Connect the connection asynchronously in a background thread + var connectionDetails = TestObjects.GetTestConnectionDetails(); + var connectTask = Task.Run(async () => + { + return await connectionService + .Connect(new ConnectParams() + { + OwnerUri = testFile, + Connection = connectionDetails + }); + }); + + // Wait for the connection to call OpenAsync() + Assert.True(TestUtils.WaitFor(() => ready)); + + // Send a cancellation request + var cancelResult = connectionService + .CancelConnect(new CancelConnectParams() + { + OwnerUri = testFile + }); + + // Wait for the connection task to finish + connectTask.Wait(); + + // Verify that the connection was cancelled (no connection was created) + Assert.Null(connectTask.Result.ConnectionId); + + // Verify that the cancel succeeded + Assert.True(cancelResult); + } + + [Fact] + public async void CanCancelConnectRequestByConnecting() + { + var testFile = "file:///my/test/file.sql"; + + // Given a connection that times out and responds to cancellation + var mockConnection = new Mock { CallBase = true }; + CancellationToken token; + bool ready = false; + mockConnection.Setup(x => x.OpenAsync(Moq.It.IsAny())) + .Callback(t => + { + // Pass the token to the return handler and signal the main thread to cancel + token = t; + ready = true; + }) + .Returns(() => + { + if (TestUtils.WaitFor(() => token.IsCancellationRequested)) + { + throw new OperationCanceledException(); + } + else + { + return Task.FromResult(true); + } + }); + + // Given a second connection that succeeds + var mockConnection2 = new Mock { CallBase = true }; + mockConnection2.Setup(x => x.OpenAsync(Moq.It.IsAny())) + .Returns(() => Task.Run(() => {})); + + var mockFactory = new Mock(); + mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(mockConnection.Object) + .Returns(mockConnection2.Object); + + + var connectionService = new ConnectionService(mockFactory.Object); + + // Connect the first connection asynchronously in a background thread + var connectionDetails = TestObjects.GetTestConnectionDetails(); + var connectTask = Task.Run(async () => + { + return await connectionService + .Connect(new ConnectParams() + { + OwnerUri = testFile, + Connection = connectionDetails + }); + }); + + // Wait for the connection to call OpenAsync() + Assert.True(TestUtils.WaitFor(() => ready)); + + // Send a cancellation by trying to connect again + var connectResult = await connectionService + .Connect(new ConnectParams() + { + OwnerUri = testFile, + Connection = connectionDetails + }); + + // Wait for the first connection task to finish + connectTask.Wait(); + + // Verify that the first connection was cancelled (no connection was created) + Assert.Null(connectTask.Result.ConnectionId); + + // Verify that the second connection succeeded + Assert.NotEmpty(connectResult.ConnectionId); + } + + [Fact] + public void CanCancelConnectRequestByDisconnecting() + { + var testFile = "file:///my/test/file.sql"; + + // Given a connection that times out and responds to cancellation + var mockConnection = new Mock { CallBase = true }; + CancellationToken token; + bool ready = false; + mockConnection.Setup(x => x.OpenAsync(Moq.It.IsAny())) + .Callback(t => + { + // Pass the token to the return handler and signal the main thread to cancel + token = t; + ready = true; + }) + .Returns(() => + { + if (TestUtils.WaitFor(() => token.IsCancellationRequested)) + { + throw new OperationCanceledException(); + } + else + { + return Task.FromResult(true); + } + }); + + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(mockConnection.Object); + + + var connectionService = new ConnectionService(mockFactory.Object); + + // Connect the first connection asynchronously in a background thread + var connectionDetails = TestObjects.GetTestConnectionDetails(); + var connectTask = Task.Run(async () => + { + return await connectionService + .Connect(new ConnectParams() + { + OwnerUri = testFile, + Connection = connectionDetails + }); + }); + + // Wait for the connection to call OpenAsync() + Assert.True(TestUtils.WaitFor(() => ready)); + + // Send a cancellation by trying to disconnect + var disconnectResult = connectionService + .Disconnect(new DisconnectParams() + { + OwnerUri = testFile + }); + + // Wait for the first connection task to finish + connectTask.Wait(); + + // Verify that the first connection was cancelled (no connection was created) + Assert.Null(connectTask.Result.ConnectionId); + + // Verify that the disconnect failed (since it caused a cancellation) + Assert.False(disconnectResult); + } + /// /// Verify that we can connect to the default database when no database name is /// provided as a parameter. @@ -59,12 +268,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection [Theory] [InlineDataAttribute(null)] [InlineDataAttribute("")] - public void CanConnectWithEmptyDatabaseName(string databaseName) + public async void CanConnectWithEmptyDatabaseName(string databaseName) { // Connect var connectionDetails = TestObjects.GetTestConnectionDetails(); connectionDetails.DatabaseName = databaseName; - var connectionResult = + var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { @@ -83,7 +292,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection [Theory] [InlineDataAttribute("master")] [InlineDataAttribute("nonMasterDb")] - public void ConnectToDefaultDatabaseRespondsWithActualDbName(string expectedDbName) + public async void ConnectToDefaultDatabaseRespondsWithActualDbName(string expectedDbName) { // Given connecting with empty database name will return the expected DB name var connectionMock = new Mock { CallBase = true }; @@ -99,7 +308,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection var connectionDetails = TestObjects.GetTestConnectionDetails(); connectionDetails.DatabaseName = string.Empty; - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -118,14 +327,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// connection, we disconnect first before connecting. /// [Fact] - public void ConnectingWhenConnectionExistCausesDisconnectThenConnect() + public async void ConnectingWhenConnectionExistCausesDisconnectThenConnect() { bool callbackInvoked = false; // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -146,7 +355,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection ); // send annother connect request - connectionResult = + connectionResult = await connectionService .Connect(new ConnectParams() { @@ -165,7 +374,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verify that when connecting with invalid credentials, an error is thrown. /// [Fact] - public void ConnectingWithInvalidCredentialsYieldsErrorMessage() + public async void ConnectingWithInvalidCredentialsYieldsErrorMessage() { var testConnectionDetails = TestObjects.GetTestConnectionDetails(); var invalidConnectionDetails = new ConnectionDetails(); @@ -175,7 +384,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection invalidConnectionDetails.Password = "invalidPassword"; // Connect to test db with invalid credentials - var connectionResult = + var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { @@ -204,10 +413,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection [InlineData("Integrated", "file://my/sample/file.sql", null, "test", "sa", "123456")] [InlineData("Integrated", "", "my-server", "test", "sa", "123456")] [InlineData("Integrated", "file://my/sample/file.sql", "", "test", "sa", "123456")] - public void ConnectingWithInvalidParametersYieldsErrorMessage(string authType, string ownerUri, string server, string database, string userName, string password) + public async void ConnectingWithInvalidParametersYieldsErrorMessage(string authType, string ownerUri, string server, string database, string userName, string password) { // Connect with invalid parameters - var connectionResult = + var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { @@ -238,10 +447,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection [InlineData("sa", "")] [InlineData(null, "12345678")] [InlineData("", "12345678")] - public void ConnectingWithNoUsernameOrPasswordWorksForIntegratedAuth(string userName, string password) + public async void ConnectingWithNoUsernameOrPasswordWorksForIntegratedAuth(string userName, string password) { // Connect - var connectionResult = + var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { @@ -263,10 +472,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verify that when connecting with a null parameters object, an error is thrown. /// [Fact] - public void ConnectingWithNullParametersObjectYieldsErrorMessage() + public async void ConnectingWithNullParametersObjectYieldsErrorMessage() { // Connect with null parameters - var connectionResult = + var connectionResult = await TestObjects.GetTestConnectionService() .Connect(null); @@ -330,7 +539,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verify that a connection changed event is fired when the database context changes. /// [Fact] - public void ConnectionChangedEventIsFiredWhenDatabaseContextChanges() + public async void ConnectionChangedEventIsFiredWhenDatabaseContextChanges() { var serviceHostMock = new Mock(); @@ -339,7 +548,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection // Set up an initial connection string ownerUri = "file://my/sample/file.sql"; - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -364,11 +573,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verify that the SQL parser correctly detects errors in text /// [Fact] - public void ConnectToDatabaseTest() + public async void ConnectToDatabaseTest() { // connect to a database instance string ownerUri = "file://my/sample/file.sql"; - var connectionResult = + var connectionResult = await TestObjects.GetTestConnectionService() .Connect(new ConnectParams() { @@ -384,12 +593,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verify that we can disconnect from an active connection succesfully /// [Fact] - public void DisconnectFromDatabaseTest() + public async void DisconnectFromDatabaseTest() { // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -414,14 +623,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Test that when a disconnect is performed, the callback event is fired /// [Fact] - public void DisconnectFiresCallbackEvent() + public async void DisconnectFiresCallbackEvent() { bool callbackInvoked = false; // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -458,12 +667,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Test that disconnecting an active connection removes the Owner URI -> ConnectionInfo mapping /// [Fact] - public void DisconnectRemovesOwnerMapping() + public async void DisconnectRemovesOwnerMapping() { // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -498,12 +707,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection [InlineDataAttribute(null)] [InlineDataAttribute("")] - public void DisconnectValidatesParameters(string disconnectUri) + public async void DisconnectValidatesParameters(string disconnectUri) { // first connect string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -530,7 +739,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verifies the the list databases operation lists database names for the server used by a connection. /// [Fact] - public void ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames() + public async void ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames() { // Result set for the query of database names Dictionary[] data = @@ -550,7 +759,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection // connect to a database instance string ownerUri = "file://my/sample/file.sql"; - var connectionResult = + var connectionResult = await connectionService .Connect(new ConnectParams() { @@ -579,7 +788,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verify that the SQL parser correctly detects errors in text /// [Fact] - public void OnConnectionCallbackHandlerTest() + public async void OnConnectionCallbackHandlerTest() { bool callbackInvoked = false; @@ -593,7 +802,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection ); // connect to a database instance - var connectionResult = connectionService.Connect(TestObjects.GetTestConnectionParams()); + var connectionResult = await connectionService.Connect(TestObjects.GetTestConnectionParams()); // verify that a valid connection id was returned Assert.True(callbackInvoked); @@ -603,14 +812,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Connection /// Verify when a connection is created that the URI -> Connection mapping is created in the connection service. /// [Fact] - public void TestConnectRequestRegistersOwner() + public async void TestConnectRequestRegistersOwner() { // Given a request to connect to a database var service = TestObjects.GetTestConnectionService(); var connectParams = TestObjects.GetTestConnectionParams(); // connect to a database instance - var connectionResult = service.Connect(connectParams); + var connectionResult = await service.Connect(connectParams); // verify that a valid connection id was returned Assert.NotNull(connectionResult.ConnectionId); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs index 1a925a88..7d5422ed 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/LanguageServer/LanguageServiceTests.cs @@ -172,7 +172,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.LanguageServices string ownerUri = "file://my/sample/file.sql"; var connectionService = TestObjects.GetTestConnectionService(); var connectionResult = - connectionService + await connectionService .Connect(new ConnectParams() { OwnerUri = ownerUri, diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs index 087a87cb..d27fe156 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -19,7 +19,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public class CancelTests { [Fact] - public void CancelInProgressQueryTest() + public async void CancelInProgressQueryTest() { // Set up file for returning the query var fileMock = new Mock(); @@ -32,7 +32,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I request a query (doesn't matter what kind) and execute it - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.GetSubSectionDocument(), OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); @@ -55,7 +55,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void CancelExecutedQueryTest() + public async void CancelExecutedQueryTest() { // Set up file for returning the query @@ -67,7 +67,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request a query (doesn't matter what kind) and wait for execution - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); @@ -89,13 +89,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void CancelNonExistantTest() + public async void CancelNonExistantTest() { var workspaceService = new Mock>(); // If: // ... I request to cancel a query that doesn't exist - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); var cancelParams = new QueryCancelParams {OwnerUri = "Doesn't Exist"}; QueryCancelResult result = null; var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 78b708d9..d9391970 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -10,6 +10,7 @@ using System.Data.Common; using System.IO; using System.Data.SqlClient; using System.Threading; +using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlServer.Management.Common; @@ -277,12 +278,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution }; } - public static QueryExecutionService GetPrimedExecutionService(ISqlConnectionFactory factory, bool isConnected, WorkspaceService workspaceService) + public static async Task GetPrimedExecutionService(ISqlConnectionFactory factory, bool isConnected, WorkspaceService workspaceService) { var connectionService = new ConnectionService(factory); if (isConnected) { - connectionService.Connect(new ConnectParams + await connectionService.Connect(new ConnectParams { Connection = GetTestConnectionDetails(), OwnerUri = OwnerUri diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs index 2837e892..b3ff5efd 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs @@ -37,7 +37,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void DisposeExecutedQuery() + public async void DisposeExecutedQuery() { // Set up file for returning the query var fileMock = new Mock(); @@ -48,7 +48,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request a query (doesn't matter what kind) - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -70,12 +70,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void QueryDisposeMissingQuery() + public async void QueryDisposeMissingQuery() { var workspaceService = new Mock>(); // If: // ... I attempt to dispose a query that doesn't exist - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; QueryDisposeResult result = null; var disposeRequest = GetQueryDisposeResultContextMock(qdr => result = qdr, null); @@ -99,7 +99,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) .Returns(fileMock.Object); // ... We need a query service - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); // If: diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs index c7e8ac0d..2484e233 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -418,7 +418,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution #region Service Tests [Fact] - public void QueryExecuteValidNoResultsTest() + public async void QueryExecuteValidNoResultsTest() { // Given: // ... Default settings are stored in the workspace service @@ -433,7 +433,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request to execute a valid query with no results - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var queryParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; QueryExecuteResult result = null; @@ -461,7 +461,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void QueryExecuteValidResultsTest() + public async void QueryExecuteValidResultsTest() { // Set up file for returning the query @@ -473,7 +473,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request to execute a valid query with results - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true, + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; @@ -503,13 +503,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void QueryExecuteUnconnectedUriTest() + public async void QueryExecuteUnconnectedUriTest() { var workspaceService = new Mock>(); // If: // ... I request to execute a query using a file URI that isn't connected - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QuerySelection = Common.WholeDocument }; QueryExecuteResult result = null; @@ -528,7 +528,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void QueryExecuteInProgressTest() + public async void QueryExecuteInProgressTest() { // Set up file for returning the query @@ -541,7 +541,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I request to execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; // Note, we don't care about the results of the first request @@ -566,7 +566,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void QueryExecuteCompletedTest() + public async void QueryExecuteCompletedTest() { // Set up file for returning the query @@ -579,7 +579,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I request to execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; // Note, we don't care about the results of the first request @@ -606,7 +606,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Theory] [InlineData(null)] - public void QueryExecuteMissingSelectionTest(SelectionData selection) + public async void QueryExecuteMissingSelectionTest(SelectionData selection) { // Set up file for returning the query @@ -618,7 +618,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request to execute a query with a missing query string - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = selection }; QueryExecuteResult result = null; @@ -639,7 +639,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void QueryExecuteInvalidQueryTest() + public async void QueryExecuteInvalidQueryTest() { // Set up file for returning the query var fileMock = new Mock(); @@ -650,7 +650,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I request to execute a query that is invalid - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, true), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, true), true, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; QueryExecuteResult result = null; diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs index 153e2977..e3c38ab5 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs @@ -26,7 +26,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test save results to a file as CSV with correct parameters /// [Fact] - public void SaveResultsAsCsvSuccessTest() + public async void SaveResultsAsCsvSuccessTest() { // Set up file for returning the query @@ -37,7 +37,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) .Returns(fileMock.Object); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -72,7 +72,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test save results to a file as CSV with a selection of cells and correct parameters /// [Fact] - public void SaveResultsAsCsvWithSelectionSuccessTest() + public async void SaveResultsAsCsvWithSelectionSuccessTest() { // Set up file for returning the query @@ -84,7 +84,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -123,7 +123,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test handling exception in saving results to CSV file /// [Fact] - public void SaveResultsAsCsvExceptionTest() + public async void SaveResultsAsCsvExceptionTest() { // Set up file for returning the query @@ -135,7 +135,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -164,12 +164,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test saving results to CSV file when the requested result set is no longer active /// [Fact] - public void SaveResultsAsCsvQueryNotFoundTest() + public async void SaveResultsAsCsvQueryNotFoundTest() { var workspaceService = new Mock>(); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -196,7 +196,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test save results to a file as JSON with correct parameters /// [Fact] - public void SaveResultsAsJsonSuccessTest() + public async void SaveResultsAsJsonSuccessTest() { // Set up file for returning the query @@ -207,7 +207,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) .Returns(fileMock.Object); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -241,7 +241,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test save results to a file as JSON with a selection of cells and correct parameters /// [Fact] - public void SaveResultsAsJsonWithSelectionSuccessTest() + public async void SaveResultsAsJsonWithSelectionSuccessTest() { // Set up file for returning the query var fileMock = new Mock(); @@ -252,7 +252,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -290,7 +290,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test handling exception in saving results to JSON file /// [Fact] - public void SaveResultsAsJsonExceptionTest() + public async void SaveResultsAsJsonExceptionTest() { // Set up file for returning the query var fileMock = new Mock(); @@ -300,7 +300,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) .Returns(fileMock.Object); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -329,11 +329,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// Test saving results to JSON file when the requested result set is no longer active /// [Fact] - public void SaveResultsAsJsonQueryNotFoundTest() + public async void SaveResultsAsJsonQueryNotFoundTest() { var workspaceService = new Mock>(); // Execute a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var executeRequest = GetQueryExecuteResultContextMock(null, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs index 8fcc9386..7b57971b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs @@ -142,7 +142,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I have a query that has results (doesn't matter what) - var queryService = Common.GetPrimedExecutionService( + var queryService = await Common.GetPrimedExecutionService( Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; @@ -165,13 +165,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void SubsetServiceMissingQueryTest() + public async void SubsetServiceMissingQueryTest() { var workspaceService = new Mock>(); // If: // ... I ask for a set of results for a file that hasn't executed a query - var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; QueryExecuteSubsetResult result = null; var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); @@ -187,7 +187,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void SubsetServiceUnexecutedQueryTest() + public async void SubsetServiceUnexecutedQueryTest() { // Set up file for returning the query @@ -199,7 +199,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution .Returns(fileMock.Object); // If: // ... I have a query that hasn't finished executing (doesn't matter what) - var queryService = Common.GetPrimedExecutionService( + var queryService = await Common.GetPrimedExecutionService( Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; @@ -223,13 +223,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Fact] - public void SubsetServiceOutOfRangeSubsetTest() + public async void SubsetServiceOutOfRangeSubsetTest() { var workspaceService = new Mock>(); // If: // ... I have a query that doesn't have any result sets - var queryService = Common.GetPrimedExecutionService( + var queryService = await Common.GetPrimedExecutionService( Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs index 9a5f8ce1..b2d52180 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestUtils.cs @@ -1,5 +1,6 @@ using System; using System.Runtime.InteropServices; +using System.Threading; namespace Microsoft.SqlTools.ServiceLayer.Test.Utility { @@ -21,5 +22,23 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility test(); } } + + /// + /// Wait for a condition to be true for a limited amount of time. + /// + /// Function that returns a boolean on a condition + /// Number of milliseconds to wait between test intervals. + /// Number of test intervals to perform before giving up. + /// True if the condition was met before the test interval limit. + public static bool WaitFor(Func condition, int intervalMilliseconds = 10, int intervalCount = 200) + { + int count = 0; + while (count++ < intervalCount && !condition.Invoke()) + { + Thread.Sleep(intervalMilliseconds); + } + + return (count < intervalCount); + } } }