diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs new file mode 100644 index 00000000..3eb87f4f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryCancelRequest.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query cancellation request + /// + public class QueryCancelParams + { + public string OwnerUri { get; set; } + } + + /// + /// Parameters to return as the result of a query dispose request + /// + public class QueryCancelResult + { + /// + /// Any error messages that occurred during disposing the result set. Optional, can be set + /// to null if there were no errors. + /// + public string Messages { get; set; } + } + + public class QueryCancelRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/cancel"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index 887bbbaf..ec25fd6c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -233,6 +233,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }; } + /// + /// Cancels the query by issuing the cancellation token + /// + public void Cancel() + { + // Make sure that the query hasn't completed execution + if (HasExecuted) + { + throw new InvalidOperationException("The query has already completed, it cannot be cancelled."); + } + + // Issue the cancellation token for the query + cancellationSource.Cancel(); + } + /// /// Delegate handler for storing messages that are returned from the server /// NOTE: Only messages that are below a certain severity will be returned via this diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 389be092..942beef3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -73,6 +73,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution serviceHost.SetRequestHandler(QueryExecuteRequest.Type, HandleExecuteRequest); serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest); serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest); + serviceHost.SetRequestHandler(QueryCancelRequest.Type, HandleCancelRequest); // Register handler for shutdown event serviceHost.RegisterShutdownTask((shutdownParams, requestContext) => @@ -178,6 +179,51 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } + public async Task HandleCancelRequest(QueryCancelParams cancelParams, + RequestContext requestContext) + { + try + { + // Attempt to find the query for the owner uri + Query result; + if (!ActiveQueries.TryGetValue(cancelParams.OwnerUri, out result)) + { + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Failed to cancel query, ID not found." + }); + return; + } + + // Cancel the query + result.Cancel(); + + // Attempt to dispose the query + if (!ActiveQueries.TryRemove(cancelParams.OwnerUri, out result)) + { + // It really shouldn't be possible to get to this scenario, but we'll cover it anyhow + await requestContext.SendResult(new QueryCancelResult + { + Messages = "Query successfully cancelled, failed to dispose query. ID not found." + }); + return; + } + + await requestContext.SendResult(new QueryCancelResult()); + } + catch (InvalidOperationException e) + { + await requestContext.SendResult(new QueryCancelResult + { + Messages = e.Message + }); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + #endregion #region Private Helpers diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs new file mode 100644 index 00000000..dbc344f8 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -0,0 +1,124 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class CancelTests + { + [Fact] + public void CancelInProgressQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and execute it + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri }; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a successful event (no messages) + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.Null(result.Messages); + + // ... The query should have been disposed as well + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void CancelExecutedQueryTest() + { + // If: + // ... I request a query (doesn't matter what kind) and wait for execution + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams {QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri}; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And then I request to cancel the query + var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + + // ... The query should not have been disposed + Assert.NotEmpty(queryService.ActiveQueries); + } + + [Fact] + public void CancelNonExistantTest() + { + // If: + // ... I request to cancel a query that doesn't exist + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var cancelParams = new QueryCancelParams {OwnerUri = "Doesn't Exist"}; + QueryCancelResult result = null; + var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); + queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + + // Then: + // ... I should have seen a result event with an error message + VerifyQueryCancelCallCount(cancelRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + } + + #region Mocking + + private static Mock> GetQueryCancelResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private static void VerifyQueryCancelCallCount(Mock> mock, + Times sendResultCalls, Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +}