diff --git a/Common.props b/Common.props index 4e75c1bf..18c43558 100644 --- a/Common.props +++ b/Common.props @@ -1,7 +1,7 @@ $(MSBuildAllProjects);$(MSBuildThisFileFullPath) - 150.18085.0-preview + 150.18096.0-preview true true diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index 1fda8fad..e5c85450 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -226,7 +226,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { get { - if (!HasExecuted) + if (!HasExecuted && !HasCancelled) { throw new InvalidOperationException("Query has not been executed."); } @@ -259,6 +259,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } + /// + /// if the query has been cancelled (before execution started) + /// + public bool HasCancelled { get; private set; } + /// /// The text of the query to execute /// @@ -280,6 +285,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } // Issue the cancellation token for the query + this.HasCancelled = true; cancellationSource.Cancel(); } @@ -368,9 +374,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution ReliableSqlConnection sqlConn = null; try { + // check for cancellation token before actually making connection + cancellationSource.Token.ThrowIfCancellationRequested(); + // Mark that we've internally executed hasExecuteBeenCalled = true; - + // Don't actually execute if there aren't any batches to execute if (Batches.Length == 0) { @@ -429,6 +438,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } catch (Exception e) { + if (e is OperationCanceledException) + { + await BatchMessageSent(new ResultMessage(SR.QueryServiceQueryCancelled, false, null)); + } // Call the query failure callback if (QueryFailed != null) { diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index d48b1153..027ba542 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -109,6 +109,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// internal ConcurrentDictionary ActiveQueries => queries.Value; + /// + /// Internal task for testability + /// + internal Task WorkTask { get; private set; } + /// /// Instance of the connection service, used to get the connection info for a given owner URI /// @@ -130,7 +135,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Holds a map from the simple execute unique GUID and the underlying task that is being ran /// - private readonly Lazy> simpleExecuteRequests = + private readonly Lazy> simpleExecuteRequests = new Lazy>(() => new ConcurrentDictionary()); /// @@ -177,23 +182,34 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Handles request to execute a selection of a document in the workspace service /// - internal Task HandleExecuteRequest(ExecuteRequestParamsBase executeParams, + internal async Task HandleExecuteRequest(ExecuteRequestParamsBase executeParams, RequestContext requestContext) { - // Setup actions to perform upon successful start and on failure to start - Func> queryCreateSuccessAction = async q => { - await requestContext.SendResult(new ExecuteRequestResult()); - Logger.Write(TraceEventType.Stop, $"Response for Query: '{executeParams.OwnerUri} sent. Query Complete!"); - return true; - }; - Func queryCreateFailureAction = message => + try { - Logger.Write(TraceEventType.Warning, $"Failed to create Query: '{executeParams.OwnerUri}. Message: '{message}' Complete!"); - return requestContext.SendError(message); - }; + // Setup actions to perform upon successful start and on failure to start + Func> queryCreateSuccessAction = async q => + { + await requestContext.SendResult(new ExecuteRequestResult()); + Logger.Write(TraceEventType.Stop, $"Response for Query: '{executeParams.OwnerUri} sent. Query Complete!"); + return true; + }; + Func queryCreateFailureAction = message => + { + Logger.Write(TraceEventType.Warning, $"Failed to create Query: '{executeParams.OwnerUri}. Message: '{message}' Complete!"); + return requestContext.SendError(message); + }; - // Use the internal handler to launch the query - return InterServiceExecuteQuery(executeParams, null, requestContext, queryCreateSuccessAction, queryCreateFailureAction, null, null); + // Use the internal handler to launch the query + WorkTask = Task.Run(async () => + { + await InterServiceExecuteQuery(executeParams, null, requestContext, queryCreateSuccessAction, queryCreateFailureAction, null, null); + }); + } + catch (Exception ex) + { + await requestContext.SendError(ex.ToString()); + } } /// @@ -219,14 +235,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution await requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri); return; } - + ConnectParams connectParams = new ConnectParams { OwnerUri = randomUri, Connection = connInfo.ConnectionDetails, Type = ConnectionType.Default }; - + Task workTask = Task.Run(async () => { await ConnectionService.Connect(connectParams); @@ -243,26 +259,26 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution try { // check to make sure any results were recieved - if (query.Batches.Length == 0 - || query.Batches[0].ResultSets.Count == 0) + if (query.Batches.Length == 0 + || query.Batches[0].ResultSets.Count == 0) { await requestContext.SendError(SR.QueryServiceResultSetHasNoResults); return; - } + } long rowCount = query.Batches[0].ResultSets[0].RowCount; // check to make sure there is a safe amount of rows to load into memory - if (rowCount > Int32.MaxValue) + if (rowCount > Int32.MaxValue) { await requestContext.SendError(SR.QueryServiceResultSetTooLarge); return; } - + SimpleExecuteResult result = new SimpleExecuteResult { RowCount = rowCount, ColumnInfo = query.Batches[0].ResultSets[0].Columns, - Rows = new DbCellValue[0][] + Rows = new DbCellValue[0][] }; if (rowCount > 0) @@ -280,8 +296,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution result.Rows = subset.Rows; } await requestContext.SendResult(result); - } - finally + } + finally { Query removedQuery; Task removedTask; @@ -306,7 +322,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution ActiveSimpleExecuteRequests.TryAdd(randomUri, workTask); } - catch(Exception ex) + catch (Exception ex) { await requestContext.SendError(ex.ToString()); } @@ -335,7 +351,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } - /// + /// /// Handles a request to get an execution plan /// internal async Task HandleExecutionPlanRequest(QueryExecutionPlanParams planParams, @@ -502,17 +518,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Callback to call when query has completed execution with errors. May be null. /// - public async Task InterServiceExecuteQuery(ExecuteRequestParamsBase executeParams, + public async Task InterServiceExecuteQuery(ExecuteRequestParamsBase executeParams, ConnectionInfo connInfo, IEventSender queryEventSender, Func> queryCreateSuccessFunc, Func queryCreateFailFunc, - Query.QueryAsyncEventHandler querySuccessFunc, + Query.QueryAsyncEventHandler querySuccessFunc, Query.QueryAsyncErrorEventHandler queryFailureFunc) { Validate.IsNotNull(nameof(executeParams), executeParams); Validate.IsNotNull(nameof(queryEventSender), queryEventSender); - + Query newQuery; try { @@ -622,7 +638,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // if any oldQuery exists on the executeParams.OwnerUri but it has not yet executed, // then shouldn't we cancel and clean out that query since we are about to create a new query object on the current OwnerUri. // - if (ActiveQueries.TryGetValue(executeParams.OwnerUri, out oldQuery) && oldQuery.HasExecuted) + if (ActiveQueries.TryGetValue(executeParams.OwnerUri, out oldQuery) && (oldQuery.HasExecuted || oldQuery.HasCancelled)) { oldQuery.Dispose(); ActiveQueries.TryRemove(executeParams.OwnerUri, out oldQuery); @@ -815,7 +831,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return GetSqlTextFromSelectionData(docRequest.OwnerUri, docRequest.QuerySelection); } - // If it is a document statement, we'll retrieve the text from the document + // If it is a document statement, we'll retrieve the text from the document ExecuteDocumentStatementParams stmtRequest = request as ExecuteDocumentStatementParams; if (stmtRequest != null) { diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/CancelTests.cs index 220e871e..3230a21e 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/CancelTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/CancelTests.cs @@ -4,6 +4,8 @@ // using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests; using Microsoft.SqlTools.ServiceLayer.SqlContext; @@ -29,6 +31,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; queryService.ActiveQueries[Constants.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution @@ -42,8 +45,9 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution await queryService.HandleCancelRequest(cancelParams, cancelRequest.Object); // Then: - // ... The query should not have been disposed + // ... The query should not have been disposed but should have been cancelled Assert.Equal(1, queryService.ActiveQueries.Count); + Assert.Equal(true, queryService.ActiveQueries[Constants.OwnerUri].HasCancelled); cancelRequest.Validate(); } @@ -58,6 +62,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; // ... And then I request to cancel the query @@ -71,8 +76,9 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution await queryService.HandleCancelRequest(cancelParams, cancelRequest.Object); // Then: - // ... The query should not have been disposed + // ... The query should not have been disposed and cancel should not have excecuted Assert.NotEmpty(queryService.ActiveQueries); + Assert.Equal(false, queryService.ActiveQueries[Constants.OwnerUri].HasCancelled); cancelRequest.Validate(); } @@ -93,5 +99,40 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution await queryService.HandleCancelRequest(cancelParams, cancelRequest.Object); cancelRequest.Validate(); } + + [Fact] + public async Task CancelQueryBeforeExecutionStartedTest() + { + // Setup query settings + QueryExecutionSettings querySettings = new QueryExecutionSettings + { + ExecutionPlanOptions = new ExecutionPlanOptions + { + IncludeActualExecutionPlanXml = false, + IncludeEstimatedExecutionPlanXml = true + } + }; + + // Create query with a failure callback function + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false, false); + ConnectionService.Instance.OwnerToConnectionMap[ci.OwnerUri] = ci; + Query query = new Query(Constants.StandardQuery, ci, querySettings, MemoryFileSystem.GetFileStreamFactory()); + + string errorMessage = null; + Query.QueryAsyncErrorEventHandler failureCallback = async (q, e) => + { + errorMessage = "Error Occured"; + }; + query.QueryFailed += failureCallback; + + query.Cancel(); + query.Execute(); + await query.ExecutionTask; + + // Validate that query has not been executed but cancelled and query failed called function was called + Assert.Equal(true, query.HasCancelled); + Assert.Equal(false, query.HasExecuted); + Assert.Equal("Error Occured", errorMessage); + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs index 9c03f23e..0c883a1b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Common.cs @@ -183,6 +183,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution HostingProtocol.RequestContext requestContext) { await service.HandleExecuteRequest(qeParams, requestContext); + await service.WorkTask; if (service.ActiveQueries.ContainsKey(qeParams.OwnerUri) && service.ActiveQueries[qeParams.OwnerUri].ExecutionTask != null) { await service.ActiveQueries[qeParams.OwnerUri].ExecutionTask; diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/DisposeTests.cs index 56f480fc..9be9ab3d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/DisposeTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/DisposeTests.cs @@ -44,6 +44,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution var executeParams = new ExecuteDocumentSelectionParams {QuerySelection = null, OwnerUri = Constants.OwnerUri}; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; // ... And then I dispose of the query @@ -90,6 +91,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution var queryParams = new ExecuteDocumentSelectionParams { QuerySelection = Common.WholeDocument, OwnerUri = Constants.OwnerUri }; var requestContext = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(queryParams, requestContext.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; // ... And it sticks around as an active query diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/ExecutionPlanTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/ExecutionPlanTests.cs index 6064b316..d8f53a07 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/ExecutionPlanTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/ExecutionPlanTests.cs @@ -152,6 +152,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution }; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; // ... And I then ask for a valid execution plan @@ -201,6 +202,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution }; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; queryService.ActiveQueries[Constants.OwnerUri].Batches[0].ResultSets[0].hasStartedRead = false; @@ -232,6 +234,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution }; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; // ... And I then ask for an execution plan from a result set diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ServiceIntegrationTests.cs index 37b7cda2..8969ba58 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ServiceIntegrationTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/ServiceIntegrationTests.cs @@ -61,6 +61,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams { QuerySelection = null, OwnerUri = Constants.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set and get it to throw because of invalid column selection @@ -106,6 +107,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams {QuerySelection = null, OwnerUri = Constants.OwnerUri}; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set from a query @@ -173,6 +175,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams { QuerySelection = null, OwnerUri = Constants.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set and get it to throw because of invalid column selection @@ -216,6 +219,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams { QuerySelection = null, OwnerUri = Constants.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set from a query @@ -282,6 +286,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams {QuerySelection = null, OwnerUri = Constants.OwnerUri}; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set and get it to throw because of invalid column selection @@ -325,6 +330,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams {QuerySelection = null, OwnerUri = Constants.OwnerUri}; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set from a query @@ -393,6 +399,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams { QuerySelection = null, OwnerUri = Constants.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set and get it to throw because of invalid column selection @@ -436,6 +443,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults var executeParams = new ExecuteDocumentSelectionParams { QuerySelection = null, OwnerUri = Constants.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await qes.HandleExecuteRequest(executeParams, executeRequest.Object); + await qes.WorkTask; await qes.ActiveQueries[Constants.OwnerUri].ExecutionTask; // If: I attempt to save a result set from a query diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SubsetTests.cs index f225b9f5..587073fe 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SubsetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SubsetTests.cs @@ -140,6 +140,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution var executeParams = new ExecuteDocumentSelectionParams {QuerySelection = null, OwnerUri = Constants.OwnerUri}; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; // ... And I then ask for a valid set of results from it @@ -179,6 +180,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution var executeParams = new ExecuteDocumentSelectionParams { QuerySelection = null, OwnerUri = Constants.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; queryService.ActiveQueries[Constants.OwnerUri].Batches[0].ResultSets[0].hasStartedRead = false; @@ -201,6 +203,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution var executeParams = new ExecuteDocumentSelectionParams { QuerySelection = null, OwnerUri = Constants.OwnerUri }; var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.WorkTask; await queryService.ActiveQueries[Constants.OwnerUri].ExecutionTask; // ... And I then ask for a set of results from it