From 60edcc3057ae6bf2d79aad180d826b5bce229f2d Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Tue, 11 Oct 2016 10:51:52 -0700 Subject: [PATCH] Make query execution truly asynchronous (#83) The two main changes in this pull request: Launching query execution as an asynchronous task that performs a callback upon completion or failure of a query. (Which also sets us up for callbacks progressive results) Moving away from using the Result of a query execution to return an error. Instead we'll use an error event to return an error Additionally, some nice refactoring and cleaning up of the unit tests to take advantage of the cool RequestContext mock tooling by @kevcunnane * Initial commit of refactor to run execution truely asynchronously * Moving the storage of the task into Query class Callbacks for completion of a query and failure of a query are setup as events in the Query class. This actually sets us up for a very nice framework for adding batch and resultset completion callbacks. However, this also exposes a problem with cancelling queries and returning errors -- we don't properly handle errors during execution of a query (aside from DB errors). * Wrapping things up in order to submit for code review * Adding fixes as per comments --- .../QueryExecution/Batch.cs | 24 ++- .../QueryExecution/Query.cs | 95 ++++++++--- .../QueryExecution/QueryExecutionService.cs | 65 ++++---- src/Microsoft.SqlTools.ServiceLayer/sr.cs | 21 ++- src/Microsoft.SqlTools.ServiceLayer/sr.resx | 9 + .../sr.strings | 4 + .../QueryExecution/CancelTests.cs | 10 +- .../QueryExecution/Common.cs | 20 ++- .../QueryExecution/DisposeTests.cs | 5 +- .../QueryExecution/ExecuteTests.cs | 102 +++++++----- .../QueryExecution/SaveResultsTests.cs | 156 ++++-------------- .../QueryExecution/SubsetTests.cs | 43 ++--- 12 files changed, 272 insertions(+), 282 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs index 51a35a7d..1bbce090 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -31,7 +31,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution private bool disposed; /// - /// Factory for creating readers/writrs for the output of the batch + /// Factory for creating readers/writers for the output of the batch /// private readonly IFileStreamFactory outputFileFactory; @@ -151,7 +151,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } // Make sure we aren't using a ReliableCommad since we do not want automatic retry - Debug.Assert(!(command is ReliableSqlConnection.ReliableSqlCommand), "ReliableSqlCommand command should not be used to execute queries"); + Debug.Assert(!(command is ReliableSqlConnection.ReliableSqlCommand), + "ReliableSqlCommand command should not be used to execute queries"); // Create a command that we'll use for executing the query using (command) @@ -170,18 +171,19 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { // Create a message with the number of affected rows -- IF the query affects rows resultMessages.Add(new ResultMessage(reader.RecordsAffected >= 0 - ? SR.QueryServiceAffectedRows(reader.RecordsAffected) - : SR.QueryServiceCompletedSuccessfully)); + ? SR.QueryServiceAffectedRows(reader.RecordsAffected) + : SR.QueryServiceCompletedSuccessfully)); continue; } // This resultset has results (ie, SELECT/etc queries) - // Read until we hit the end of the result set ResultSet resultSet = new ResultSet(reader, outputFileFactory); - await resultSet.ReadResultToEnd(cancellationToken); - + // Add the result set to the results of the query resultSets.Add(resultSet); + + // Read until we hit the end of the result set + await resultSet.ReadResultToEnd(cancellationToken); // Add a message for the number of rows the query returned resultMessages.Add(new ResultMessage(SR.QueryServiceAffectedRows(resultSet.RowCount))); @@ -194,9 +196,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution HasError = true; UnwrapDbException(dbe); } - catch (Exception) + catch (TaskCanceledException) + { + resultMessages.Add(new ResultMessage(SR.QueryServiceQueryCancelled)); + throw; + } + catch (Exception e) { HasError = true; + resultMessages.Add(new ResultMessage(SR.QueryServiceQueryFailed(e.Message))); throw; } finally diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index cf2df73c..1c48d516 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -96,6 +96,22 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Properties + /// + /// Delegate type for callback when a query completes or fails + /// + /// The query that completed + public delegate Task QueryAsyncEventHandler(Query q); + + /// + /// Callback for when the query has completed successfully + /// + public event QueryAsyncEventHandler QueryCompleted; + + /// + /// Callback for when the query has failed + /// + public event QueryAsyncEventHandler QueryFailed; + /// /// The batches underneath this query /// @@ -124,6 +140,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } + internal Task ExecutionTask { get; private set; } + /// /// Whether or not the query has completed executed, regardless of success or failure /// @@ -167,10 +185,44 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution cancellationSource.Cancel(); } + public void Execute() + { + ExecutionTask = Task.Run(ExecuteInternal); + } + + /// + /// Retrieves a subset of the result sets + /// + /// The index for selecting the batch item + /// The index for selecting the result set + /// The starting row of the results + /// How many rows to retrieve + /// A subset of results + public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) + { + // Sanity check that the results are available + if (!HasExecuted) + { + throw new InvalidOperationException(SR.QueryServiceSubsetNotCompleted); + } + + // Sanity check to make sure that the batch is within bounds + if (batchIndex < 0 || batchIndex >= Batches.Length) + { + throw new ArgumentOutOfRangeException(nameof(batchIndex), SR.QueryServiceSubsetBatchOutOfRange); + } + + return Batches[batchIndex].GetSubset(resultSetIndex, startRow, rowCount); + } + + #endregion + + #region Private Helpers + /// /// Executes this query asynchronously and collects all result sets /// - public async Task Execute() + private async Task ExecuteInternal() { // Mark that we've internally executed hasExecuteBeenCalled = true; @@ -202,6 +254,20 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { await b.Execute(conn, cancellationSource.Token); } + + // Call the query execution callback + if (QueryCompleted != null) + { + await QueryCompleted(this); + } + } + catch (Exception) + { + // Call the query failure callback + if (QueryFailed != null) + { + await QueryFailed(this); + } } finally { @@ -227,7 +293,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution throw new InvalidOperationException(SR.QueryServiceMessageSenderNotSql); } - foreach(SqlError error in args.Errors) + foreach (SqlError error in args.Errors) { // Did the database context change (error code 5701)? if (error.Number == DatabaseContextChangeErrorNumber) @@ -237,31 +303,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } - /// - /// Retrieves a subset of the result sets - /// - /// The index for selecting the batch item - /// The index for selecting the result set - /// The starting row of the results - /// How many rows to retrieve - /// A subset of results - public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) - { - // Sanity check that the results are available - if (!HasExecuted) - { - throw new InvalidOperationException(SR.QueryServiceSubsetNotCompleted); - } - - // Sanity check to make sure that the batch is within bounds - if (batchIndex < 0 || batchIndex >= Batches.Length) - { - throw new ArgumentOutOfRangeException(nameof(batchIndex), SR.QueryServiceSubsetBatchOutOfRange); - } - - return Batches[batchIndex].GetSubset(resultSetIndex, startRow, rowCount); - } - #endregion #region IDisposable Implementation diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index d0ff8d1a..e6abce65 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -129,19 +129,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public async Task HandleExecuteRequest(QueryExecuteParams executeParams, RequestContext requestContext) { - try - { - // Get a query new active query - Query newQuery = await CreateAndActivateNewQuery(executeParams, requestContext); + // Get a query new active query + Query newQuery = await CreateAndActivateNewQuery(executeParams, requestContext); - // Execute the query - await ExecuteAndCompleteQuery(executeParams, requestContext, newQuery); - } - catch (Exception e) - { - // Dump any unexpected exceptions as errors - await requestContext.SendError(e.Message); - } + // Execute the query -- asynchronously + await ExecuteAndCompleteQuery(executeParams, requestContext, newQuery); } public async Task HandleResultSubsetRequest(QueryExecuteSubsetParams subsetParams, @@ -399,7 +391,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { //get column name DbColumnWrapper col = selectedResultSet.Columns[i]; - string val = row[i]?.ToString(); + string val = row[i]; jsonWriter.WritePropertyName(col.ColumnName); if (val == null) { @@ -440,10 +432,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution ConnectionInfo connectionInfo; if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connectionInfo)) { - await requestContext.SendResult(new QueryExecuteResult - { - Messages = SR.QueryServiceQueryInvalidOwnerUri - }); + await requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri); return null; } @@ -488,24 +477,22 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution Query newQuery = new Query(queryText, connectionInfo, settings, BufferFileFactory); if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) { - await requestContext.SendResult(new QueryExecuteResult - { - Messages = SR.QueryServiceQueryInProgress - }); + await requestContext.SendError(SR.QueryServiceQueryInProgress); + newQuery.Dispose(); return null; } return newQuery; } - catch (ArgumentException ane) + catch (Exception e) { - await requestContext.SendResult(new QueryExecuteResult { Messages = ane.Message }); + await requestContext.SendError(e.Message); return null; } // Any other exceptions will fall through here and be collected at the end } - private async Task ExecuteAndCompleteQuery(QueryExecuteParams executeParams, RequestContext requestContext, Query query) + private static async Task ExecuteAndCompleteQuery(QueryExecuteParams executeParams, RequestContext requestContext, Query query) { // Skip processing if the query is null if (query == null) @@ -513,21 +500,29 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - // Launch the query and respond with successfully launching it - Task executeTask = query.Execute(); + // Setup the query completion/failure callbacks + Query.QueryAsyncEventHandler callback = async q => + { + // Send back the results + QueryExecuteCompleteParams eventParams = new QueryExecuteCompleteParams + { + OwnerUri = executeParams.OwnerUri, + BatchSummaries = q.BatchSummaries + }; + await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); + }; + + query.QueryCompleted += callback; + query.QueryFailed += callback; + + // Launch this as an asynchronous task + query.Execute(); + + // Send back a result showing we were successful await requestContext.SendResult(new QueryExecuteResult { Messages = null }); - - // Wait for query execution and then send back the results - await Task.WhenAll(executeTask); - QueryExecuteCompleteParams eventParams = new QueryExecuteCompleteParams - { - OwnerUri = executeParams.OwnerUri, - BatchSummaries = query.BatchSummaries - }; - await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); } #endregion diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.cs b/src/Microsoft.SqlTools.ServiceLayer/sr.cs index 811ab975..dbe3bac6 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.cs @@ -165,6 +165,14 @@ namespace Microsoft.SqlTools.ServiceLayer } } + public static string QueryServiceQueryCancelled + { + get + { + return Keys.GetString(Keys.QueryServiceQueryCancelled); + } + } + public static string QueryServiceSubsetNotCompleted { get @@ -363,6 +371,11 @@ namespace Microsoft.SqlTools.ServiceLayer return Keys.GetString(Keys.QueryServiceErrorFormat, msg, lvl, state, line, newLine, message); } + public static string QueryServiceQueryFailed(string message) + { + return Keys.GetString(Keys.QueryServiceQueryFailed, message); + } + public static string WorkspaceServicePositionColumnOutOfRange(int line) { return Keys.GetString(Keys.WorkspaceServicePositionColumnOutOfRange, line); @@ -376,7 +389,7 @@ namespace Microsoft.SqlTools.ServiceLayer [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] public class Keys { - static ResourceManager resourceManager = new ResourceManager(typeof(SR)); + static ResourceManager resourceManager = new ResourceManager("Microsoft.SqlTools.ServiceLayer.SR", typeof(SR).GetTypeInfo().Assembly); static CultureInfo _culture = null; @@ -444,6 +457,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string QueryServiceCancelDisposeFailed = "QueryServiceCancelDisposeFailed"; + public const string QueryServiceQueryCancelled = "QueryServiceQueryCancelled"; + + public const string QueryServiceSubsetNotCompleted = "QueryServiceSubsetNotCompleted"; @@ -480,6 +496,9 @@ namespace Microsoft.SqlTools.ServiceLayer public const string QueryServiceErrorFormat = "QueryServiceErrorFormat"; + public const string QueryServiceQueryFailed = "QueryServiceQueryFailed"; + + public const string QueryServiceColumnNull = "QueryServiceColumnNull"; diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.resx b/src/Microsoft.SqlTools.ServiceLayer/sr.resx index 63d7e71b..ebcaa126 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.resx +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.resx @@ -205,6 +205,10 @@ Query successfully cancelled, failed to dispose query. Owner URI not found. + + Query was canceled by user + + The query has not completed, yet @@ -254,6 +258,11 @@ Msg {0}, Level {1}, State {2}, Line {3}{4}{5} . Parameters: 0 - msg (int), 1 - lvl (int), 2 - state (int), 3 - line (int), 4 - newLine (string), 5 - message (string) + + + Query failed: {0} + . + Parameters: 0 - message (string) (No column name) diff --git a/src/Microsoft.SqlTools.ServiceLayer/sr.strings b/src/Microsoft.SqlTools.ServiceLayer/sr.strings index a74a54d9..35bca9c5 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/sr.strings +++ b/src/Microsoft.SqlTools.ServiceLayer/sr.strings @@ -79,6 +79,8 @@ QueryServiceCancelAlreadyCompleted = The query has already completed, it cannot QueryServiceCancelDisposeFailed = Query successfully cancelled, failed to dispose query. Owner URI not found. +QueryServiceQueryCancelled = Query was canceled by user + ### Subset Request QueryServiceSubsetNotCompleted = The query has not completed, yet @@ -111,6 +113,8 @@ QueryServiceCompletedSuccessfully = Command(s) copleted successfully. QueryServiceErrorFormat(int msg, int lvl, int state, int line, string newLine, string message) = Msg {0}, Level {1}, State {2}, Line {3}{4}{5} +QueryServiceQueryFailed(string message) = Query failed: {0} + QueryServiceColumnNull = (No column name) QueryServiceRequestsNoQuery = The requested query does not exist diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs index d27fe156..2ff2045d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/CancelTests.cs @@ -24,7 +24,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // Set up file for returning the query var fileMock = new Mock(); fileMock.Setup(file => file.GetLinesInRange(It.IsAny())) - .Returns(new string[] { Common.StandardQuery }); + .Returns(new[] { Common.StandardQuery }); // Set up workspace mock var workspaceService = new Mock>(); workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) @@ -36,7 +36,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var executeParams = new QueryExecuteParams { QuerySelection = Common.GetSubSectionDocument(), OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Fake that it hasn't completed execution // ... And then I request to cancel the query @@ -71,13 +72,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var executeParams = new QueryExecuteParams {QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And then I request to cancel the query var cancelParams = new QueryCancelParams {OwnerUri = Common.OwnerUri}; QueryCancelResult result = null; var cancelRequest = GetQueryCancelResultContextMock(qcr => result = qcr, null); - queryService.HandleCancelRequest(cancelParams, cancelRequest.Object).Wait(); + await queryService.HandleCancelRequest(cancelParams, cancelRequest.Object); // Then: // ... I should have seen a result event with an error message diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index c3f6df75..bd5abe81 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -14,9 +14,6 @@ using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlServer.Management.Common; -using Microsoft.SqlServer.Management.SmoMetadataProvider; -using Microsoft.SqlServer.Management.SqlParser.Binder; -using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; @@ -95,7 +92,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false); Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory()); - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); return query; } @@ -287,6 +285,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution return new QueryExecutionService(connectionService, workspaceService) {BufferFileStreamFactory = GetFileStreamFactory()}; } + public static WorkspaceService GetPrimedWorkspaceService() + { + // Set up file for returning the query + var fileMock = new Mock(); + fileMock.SetupGet(file => file.Contents).Returns(StandardQuery); + + // Set up workspace mock + var workspaceService = new Mock>(); + workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) + .Returns(fileMock.Object); + + return workspaceService.Object; + } + #endregion } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs index b3ff5efd..2f103ec0 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs @@ -4,7 +4,6 @@ // using System; -using System.Data.Common; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution; @@ -51,7 +50,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And then I dispose of the query var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; @@ -107,6 +107,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; var requestContext = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(queryParams, requestContext.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And it sticks around as an active query Assert.Equal(1, queryService.ActiveQueries.Count); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs index 2484e233..c5694fa1 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -19,7 +19,6 @@ using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -using Microsoft.SqlTools.Test.Utility; using Moq; using Xunit; @@ -295,7 +294,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... The query should have completed successfully with one batch summary returned @@ -321,7 +321,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I Then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... The query should have completed successfully with no batch summaries returned @@ -348,7 +349,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... The query should have completed successfully with two batch summaries returned @@ -376,7 +378,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // .. I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // ... The query should have completed successfully with one batch summary returned Assert.True(query.HasExecuted); @@ -402,7 +405,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // If: // ... I then execute the query - query.Execute().Wait(); + query.Execute(); + query.ExecutionTask.Wait(); // Then: // ... There should be an error on the batch @@ -444,7 +448,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution expectedEvent: QueryExecuteCompleteEvent.Type, eventCallback: (et, cp) => completeParams = cp, errorCallback: null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, requestContext.Object); // Then: // ... No Errors should have been sent @@ -485,7 +489,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution expectedEvent: QueryExecuteCompleteEvent.Type, eventCallback: (et, cp) => completeParams = cp, errorCallback: null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, requestContext.Object); // Then: // ... No errors should have been sent @@ -512,18 +516,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QuerySelection = Common.WholeDocument }; - QueryExecuteResult result = null; - var requestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + object error = null; + var requestContext = RequestContextMocks.Create(null) + .AddErrorHandling(e => error = e); + await queryService.HandleExecuteRequest(queryParams, requestContext.Object); // Then: - // ... An error message should have been returned via the result + // ... An error should have been returned + // ... No result should have been returned // ... No completion event should have been fired - // ... No error event should have been fired // ... There should be no active queries - VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); - Assert.NotNull(result.Messages); - Assert.NotEmpty(result.Messages); + VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Once()); + Assert.IsType(error); + Assert.NotEmpty((string)error); Assert.Empty(queryService.ActiveQueries); } @@ -545,24 +550,25 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = Common.WholeDocument }; // Note, we don't care about the results of the first request - var firstRequestContext = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + var firstRequestContext = RequestContextMocks.Create(null); + await AwaitExecution(queryService, queryParams, firstRequestContext.Object); // ... And then I request another query without waiting for the first to complete queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Simulate query hasn't finished - QueryExecuteResult result = null; - var secondRequestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + object error = null; + var secondRequestContext = RequestContextMocks.Create(null) + .AddErrorHandling(e => error = e); + await AwaitExecution(queryService, queryParams, secondRequestContext.Object); // Then: - // ... No errors should have been sent - // ... A result should have been sent with an error message + // ... An error should have been sent + // ... A result should have not have been sent // ... No completion event should have been fired - // ... There should only be one active query - VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.AtMostOnce(), Times.Never()); - Assert.NotNull(result.Messages); - Assert.NotEmpty(result.Messages); - Assert.Equal(1, queryService.ActiveQueries.Count); + // ... The original query should exist + VerifyQueryExecuteCallCount(secondRequestContext, Times.Never(), Times.Never(), Times.Once()); + Assert.IsType(error); + Assert.NotEmpty((string)error); + Assert.Contains(Common.OwnerUri, queryService.ActiveQueries.Keys); } [Fact] @@ -584,15 +590,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // Note, we don't care about the results of the first request var firstRequestContext = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - - queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, firstRequestContext.Object); // ... And then I request another query after waiting for the first to complete QueryExecuteResult result = null; QueryExecuteCompleteParams complete = null; var secondRequestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); - queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, secondRequestContext.Object); // Then: // ... No errors should have been sent @@ -606,7 +611,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Theory] [InlineData(null)] - public async void QueryExecuteMissingSelectionTest(SelectionData selection) + public async Task QueryExecuteMissingSelectionTest(SelectionData selection) { // Set up file for returning the query @@ -621,18 +626,20 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QuerySelection = selection }; - QueryExecuteResult result = null; - var requestContext = - RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + object errorResult = null; + var requestContext = RequestContextMocks.Create(null) + .AddErrorHandling(error => errorResult = error); + await queryService.HandleExecuteRequest(queryParams, requestContext.Object); // Then: - // ... No errors should have been sent - // ... A result should have been sent with an error message + // ... Am error should have been sent + // ... No result should have been sent // ... No completion event should have been fired - VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); - Assert.NotNull(result.Messages); - Assert.NotEmpty(result.Messages); + // ... An active query should not have been added + VerifyQueryExecuteCallCount(requestContext, Times.Never(), Times.Never(), Times.Once()); + Assert.NotNull(errorResult); + Assert.IsType(errorResult); + Assert.DoesNotContain(Common.OwnerUri, queryService.ActiveQueries.Keys); // ... There should not be an active query Assert.Empty(queryService.ActiveQueries); @@ -657,7 +664,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution QueryExecuteCompleteParams complete = null; var requestContext = RequestContextMocks.SetupRequestContextMock(qer => result = qer, QueryExecuteCompleteEvent.Type, (et, qecp) => complete = qecp, null); - queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + await AwaitExecution(queryService, queryParams, requestContext.Object); // Then: // ... No errors should have been sent @@ -700,7 +707,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution #endregion - private void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) + private static void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) { mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); mock.Verify(rc => rc.SendEvent( @@ -709,9 +716,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); } - private DbConnection GetConnection(ConnectionInfo info) + private static DbConnection GetConnection(ConnectionInfo info) { return info.Factory.CreateSqlConnection(ConnectionService.BuildConnectionString(info.ConnectionDetails)); } + + private static async Task AwaitExecution(QueryExecutionService service, QueryExecuteParams qeParams, + RequestContext requestContext) + { + await service.HandleExecuteRequest(qeParams, requestContext); + await service.ActiveQueries[qeParams.OwnerUri].ExecutionTask; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs index e3c38ab5..522c14a7 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs @@ -7,11 +7,10 @@ using System.IO; using System.Threading.Tasks; using System.Runtime.InteropServices; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; -using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; -using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Moq; using Xunit; @@ -28,19 +27,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsCsvSuccessTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as csv with correct parameters var saveParams = new SaveResultsAsCsvRequestParams @@ -74,20 +66,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsCsvWithSelectionSuccessTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as csv with correct parameters var saveParams = new SaveResultsAsCsvRequestParams @@ -124,21 +108,13 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution /// [Fact] public async void SaveResultsAsCsvExceptionTest() - { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - + { // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as csv with incorrect filepath var saveParams = new SaveResultsAsCsvRequestParams @@ -148,7 +124,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution BatchIndex = 0, FilePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "G:\\test.csv" : "/test.csv" }; - // SaveResultRequestResult result = null; + string errMessage = null; var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (string) err); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); @@ -166,13 +142,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsCsvQueryNotFoundTest() { - + // Create a query execution service var workspaceService = new Mock>(); - // Execute a query var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); // Request to save the results as csv with query that is no longer active var saveParams = new SaveResultsAsCsvRequestParams @@ -198,19 +170,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonSuccessTest() { - - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as json with correct parameters var saveParams = new SaveResultsAsJsonRequestParams @@ -243,19 +208,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonWithSelectionSuccessTest() { - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); - // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument , OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as json with correct parameters var saveParams = new SaveResultsAsJsonRequestParams @@ -292,18 +250,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonExceptionTest() { - // Set up file for returning the query - var fileMock = new Mock(); - fileMock.SetupGet(file => file.Contents).Returns(Common.StandardQuery); - // Set up workspace mock - var workspaceService = new Mock>(); - workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny())) - .Returns(fileMock.Object); // Execute a query - var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); + var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // Request to save the results as json with incorrect filepath var saveParams = new SaveResultsAsJsonRequestParams @@ -331,12 +283,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SaveResultsAsJsonQueryNotFoundTest() { + + // Create a query service var workspaceService = new Mock>(); - // Execute a query var queryService = await Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true, workspaceService.Object); - var executeParams = new QueryExecuteParams { QuerySelection = Common.WholeDocument, OwnerUri = Common.OwnerUri }; - var executeRequest = GetQueryExecuteResultContextMock(null, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); // Request to save the results as json with query that is no longer active var saveParams = new SaveResultsAsJsonRequestParams @@ -404,52 +354,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); } - /// - /// Mock request context for executing a query - /// - /// - /// - /// - /// - /// - public static Mock> GetQueryExecuteResultContextMock( - Action resultCallback, - Action, QueryExecuteCompleteParams> eventCallback, - Action errorCallback) - { - var requestContext = new Mock>(); - - // Setup the mock for SendResult - var sendResultFlow = requestContext - .Setup(rc => rc.SendResult(It.IsAny())) - .Returns(Task.FromResult(0)); - if (resultCallback != null) - { - sendResultFlow.Callback(resultCallback); - } - - // Setup the mock for SendEvent - var sendEventFlow = requestContext.Setup(rc => rc.SendEvent( - It.Is>(m => m == QueryExecuteCompleteEvent.Type), - It.IsAny())) - .Returns(Task.FromResult(0)); - if (eventCallback != null) - { - sendEventFlow.Callback(eventCallback); - } - - // Setup the mock for SendError - var sendErrorFlow = requestContext.Setup(rc => rc.SendError(It.IsAny())) - .Returns(Task.FromResult(0)); - if (errorCallback != null) - { - sendErrorFlow.Callback(errorCallback); - } - - return requestContext; - } - - #endregion } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs index 7b57971b..4036c655 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs @@ -146,8 +146,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams {QuerySelection = null, OwnerUri = Common.OwnerUri}; - var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); + var executeRequest = RequestContextMocks.Create(null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And I then ask for a valid set of results from it var subsetParams = new QueryExecuteSubsetParams {OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0}; @@ -203,8 +204,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true, workspaceService.Object); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; - var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // ... And I then ask for a valid set of results from it @@ -224,17 +226,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution [Fact] public async void SubsetServiceOutOfRangeSubsetTest() - { - - var workspaceService = new Mock>(); + { // If: // ... I have a query that doesn't have any result sets var queryService = await Common.GetPrimedExecutionService( - Common.CreateMockFactory(null, false), true, - workspaceService.Object); + Common.CreateMockFactory(null, false), true, Common.GetPrimedWorkspaceService()); var executeParams = new QueryExecuteParams { QuerySelection = null, OwnerUri = Common.OwnerUri }; - var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); - queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + var executeRequest = RequestContextMocks.Create(null); + await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); + await queryService.ActiveQueries[Common.OwnerUri].ExecutionTask; // ... And I then ask for a set of results from it var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; @@ -259,27 +259,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Action resultCallback, Action errorCallback) { - var requestContext = new Mock>(); - - // Setup the mock for SendResult - var sendResultFlow = requestContext - .Setup(rc => rc.SendResult(It.IsAny())) - .Returns(Task.FromResult(0)); - if (resultCallback != null) - { - sendResultFlow.Callback(resultCallback); - } - - // Setup the mock for SendError - var sendErrorFlow = requestContext - .Setup(rc => rc.SendError(It.IsAny())) - .Returns(Task.FromResult(0)); - if (errorCallback != null) - { - sendErrorFlow.Callback(errorCallback); - } - - return requestContext; + return RequestContextMocks.Create(resultCallback) + .AddErrorHandling(errorCallback); } private static void VerifyQuerySubsetCallCount(Mock> mock, Times sendResultCalls,