diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IEventSender.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IEventSender.cs new file mode 100644 index 00000000..f625a88e --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IEventSender.cs @@ -0,0 +1,15 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol +{ + public interface IEventSender + { + Task SendEvent(EventType eventType, TParams eventParams); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IProtocolEndpoint.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IProtocolEndpoint.cs index 496e3d56..0450b124 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IProtocolEndpoint.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IProtocolEndpoint.cs @@ -14,7 +14,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol /// respond to requests and events, send their own requests, and listen for notifications /// sent by the other side of the endpoint /// - public interface IProtocolEndpoint : IMessageSender + public interface IProtocolEndpoint : IEventSender, IRequestSender { void SetRequestHandler( RequestType requestType, diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IMessageSender.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IRequestSender.cs similarity index 55% rename from src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IMessageSender.cs rename to src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IRequestSender.cs index 583fb3b0..ca69f5b3 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IMessageSender.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/IRequestSender.cs @@ -1,4 +1,4 @@ -// +// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. // @@ -8,16 +8,9 @@ using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { - public interface IMessageSender + public interface IRequestSender { - Task SendEvent( - EventType eventType, - TParams eventParams); - - Task SendRequest( - RequestType requestType, - TParams requestParams, + Task SendRequest(RequestType requestType, TParams requestParams, bool waitForResponse); } } - diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/ProtocolEndpoint.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/ProtocolEndpoint.cs index aaa4b7db..2a0652ac 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/ProtocolEndpoint.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/ProtocolEndpoint.cs @@ -17,7 +17,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol /// Provides behavior for a client or server endpoint that /// communicates using the specified protocol. /// - public class ProtocolEndpoint : IMessageSender, IProtocolEndpoint + public class ProtocolEndpoint : IProtocolEndpoint { private bool isStarted; private int currentMessageId; diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs index a2811f6a..ff2f4717 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs @@ -9,10 +9,10 @@ using Newtonsoft.Json.Linq; namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol { - public class RequestContext + public class RequestContext : IEventSender { - private Message requestMessage; - private MessageWriter messageWriter; + private readonly Message requestMessage; + private readonly MessageWriter messageWriter; public RequestContext(Message requestMessage, MessageWriter messageWriter) { @@ -24,7 +24,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol public virtual async Task SendResult(TResult resultDetails) { - await this.messageWriter.WriteResponse( + await this.messageWriter.WriteResponse( resultDetails, requestMessage.Method, requestMessage.Id); diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs index 70e6631c..d242d714 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs @@ -20,11 +20,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// public class QueryDisposeResult { - /// - /// Any error messages that occurred during disposing the result set. Optional, can be set - /// to null if there were no errors. - /// - public string Messages { get; set; } } public class QueryDisposeRequest diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 4db65d8e..999c126d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -87,17 +87,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// The collection of active queries /// - internal ConcurrentDictionary ActiveQueries - { - get { return queries.Value; } - } + internal ConcurrentDictionary ActiveQueries => queries.Value; /// /// Instance of the connection service, used to get the connection info for a given owner URI /// - private ConnectionService ConnectionService { get; set; } + private ConnectionService ConnectionService { get; } - private WorkspaceService WorkspaceService { get; set; } + private WorkspaceService WorkspaceService { get; } /// /// Internal storage of active queries, lazily constructed as a threadsafe dictionary @@ -105,7 +102,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution private readonly Lazy> queries = new Lazy>(() => new ConcurrentDictionary()); - private SqlToolsSettings Settings { get { return WorkspaceService.Instance.CurrentSettings; } } + private SqlToolsSettings Settings => WorkspaceService.Instance.CurrentSettings; #endregion @@ -146,20 +143,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Handles request to execute a selection of a document in the workspace service /// - public async Task HandleExecuteRequest(ExecuteRequestParamsBase executeDocumentSelectionParams, + internal Task HandleExecuteRequest(ExecuteRequestParamsBase executeParams, RequestContext requestContext) { - // Get a query new active query - Query newQuery = await CreateAndActivateNewQuery(executeDocumentSelectionParams, requestContext); + // Setup actions to perform upon successful start and on failure to start + Func queryCreationAction = () => requestContext.SendResult(new ExecuteRequestResult()); + Func queryFailAction = requestContext.SendError; - // Execute the query -- asynchronously - ExecuteAndCompleteQuery(executeDocumentSelectionParams, requestContext, newQuery); + // Use the internal handler to launch the query + return InterServiceExecuteQuery(executeParams, requestContext, queryCreationAction, queryFailAction); } /// /// Handles a request to get a subset of the results of this query /// - public async Task HandleResultSubsetRequest(SubsetParams subsetParams, + internal async Task HandleResultSubsetRequest(SubsetParams subsetParams, RequestContext requestContext) { try @@ -210,7 +208,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Handles a request to get an execution plan /// - public async Task HandleExecutionPlanRequest(QueryExecutionPlanParams planParams, + internal async Task HandleExecutionPlanRequest(QueryExecutionPlanParams planParams, RequestContext requestContext) { try @@ -240,41 +238,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Handles a request to dispose of this query /// - public async Task HandleDisposeRequest(QueryDisposeParams disposeParams, + internal async Task HandleDisposeRequest(QueryDisposeParams disposeParams, RequestContext requestContext) { - try - { - // Attempt to remove the query for the owner uri - Query result; - if (!ActiveQueries.TryRemove(disposeParams.OwnerUri, out result)) - { - await requestContext.SendResult(new QueryDisposeResult - { - Messages = SR.QueryServiceRequestsNoQuery - }); - return; - } + // Setup action for success and failure + Func successAction = () => requestContext.SendResult(new QueryDisposeResult()); + Func failureAction = requestContext.SendError; - // Cleanup the query - result.Dispose(); - - // Success - await requestContext.SendResult(new QueryDisposeResult - { - Messages = null - }); - } - catch (Exception e) - { - await requestContext.SendError(e.Message); - } + // Use the inter-service dispose functionality + await InterServiceDisposeQuery(disposeParams.OwnerUri, successAction, failureAction); } /// /// Handles a request to cancel this query if it is in progress /// - public async Task HandleCancelRequest(QueryCancelParams cancelParams, + internal async Task HandleCancelRequest(QueryCancelParams cancelParams, RequestContext requestContext) { try @@ -338,9 +316,75 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #endregion + #region Inter-Service API Handlers + + /// + /// Query execution meant to be called from another service. Utilizes callbacks to allow + /// custom actions to be taken upon creation of query and failure to create query. + /// + /// Params for creating the new query + /// Object that can send events for query execution progress + /// + /// Action to perform when query has been successfully created, right before execution of + /// the query + /// + /// Action to perform if query was not successfully created + public async Task InterServiceExecuteQuery(ExecuteRequestParamsBase executeParams, IEventSender eventSender, + Func queryCreatedAction, Func failureAction) + { + Validate.IsNotNull(nameof(executeParams), executeParams); + Validate.IsNotNull(nameof(eventSender), eventSender); + Validate.IsNotNull(nameof(queryCreatedAction), queryCreatedAction); + Validate.IsNotNull(nameof(failureAction), failureAction); + + // Get a new active query + Query newQuery = await CreateAndActivateNewQuery(executeParams, queryCreatedAction, failureAction); + + // Execute the query asynchronously + ExecuteAndCompleteQuery(executeParams.OwnerUri, eventSender, newQuery); + } + + /// + /// Query disposal meant to be called from another service. Utilizes callbacks to allow + /// custom actions to be performed on success or failure. + /// + /// The identifier of the query to be disposed + /// Action to perform on success + /// Action to perform on failure + /// + public async Task InterServiceDisposeQuery(string ownerUri, Func successAction, + Func failureAction) + { + Validate.IsNotNull(nameof(successAction), successAction); + Validate.IsNotNull(nameof(failureAction), failureAction); + + try + { + // Attempt to remove the query for the owner uri + Query result; + if (!ActiveQueries.TryRemove(ownerUri, out result)) + { + await failureAction(SR.QueryServiceRequestsNoQuery); + return; + } + + // Cleanup the query + result.Dispose(); + + // Success + await successAction(); + } + catch (Exception e) + { + await failureAction(e.Message); + } + } + + #endregion + #region Private Helpers - private async Task CreateAndActivateNewQuery(ExecuteRequestParamsBase executeParams, RequestContext requestContext) + private async Task CreateAndActivateNewQuery(ExecuteRequestParamsBase executeParams, Func successAction, Func failureAction) { try { @@ -348,7 +392,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution ConnectionInfo connectionInfo; if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connectionInfo)) { - await requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri); + await failureAction(SR.QueryServiceQueryInvalidOwnerUri); return null; } @@ -370,24 +414,24 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution Query newQuery = new Query(GetSqlText(executeParams), connectionInfo, settings, BufferFileFactory); if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) { - await requestContext.SendError(SR.QueryServiceQueryInProgress); + await failureAction(SR.QueryServiceQueryInProgress); newQuery.Dispose(); return null; } - // Send the result stating that the query was successfully started - await requestContext.SendResult(new ExecuteRequestResult()); + // Successfully created query + await successAction(); return newQuery; } catch (Exception e) { - await requestContext.SendError(e.Message); + await failureAction(e.Message); return null; } } - private static void ExecuteAndCompleteQuery(ExecuteRequestParamsBase executeDocumentSelectionParams, RequestContext requestContext, Query query) + private static void ExecuteAndCompleteQuery(string ownerUri, IEventSender eventSender, Query query) { // Skip processing if the query is null if (query == null) @@ -401,11 +445,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Send back the results QueryCompleteParams eventParams = new QueryCompleteParams { - OwnerUri = executeDocumentSelectionParams.OwnerUri, + OwnerUri = ownerUri, BatchSummaries = q.BatchSummaries }; - await requestContext.SendEvent(QueryCompleteEvent.Type, eventParams); + await eventSender.SendEvent(QueryCompleteEvent.Type, eventParams); }; Query.QueryAsyncErrorEventHandler errorCallback = async errorMessage => @@ -413,10 +457,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Send back the error message QueryCompleteParams eventParams = new QueryCompleteParams { - OwnerUri = executeDocumentSelectionParams.OwnerUri, + OwnerUri = ownerUri, //Message = errorMessage }; - await requestContext.SendEvent(QueryCompleteEvent.Type, eventParams); + await eventSender.SendEvent(QueryCompleteEvent.Type, eventParams); }; query.QueryCompleted += callback; @@ -429,10 +473,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution BatchEventParams eventParams = new BatchEventParams { BatchSummary = b.Summary, - OwnerUri = executeDocumentSelectionParams.OwnerUri + OwnerUri = ownerUri }; - await requestContext.SendEvent(BatchStartEvent.Type, eventParams); + await eventSender.SendEvent(BatchStartEvent.Type, eventParams); }; query.BatchStarted += batchStartCallback; @@ -441,10 +485,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution BatchEventParams eventParams = new BatchEventParams { BatchSummary = b.Summary, - OwnerUri = executeDocumentSelectionParams.OwnerUri + OwnerUri = ownerUri }; - await requestContext.SendEvent(BatchCompleteEvent.Type, eventParams); + await eventSender.SendEvent(BatchCompleteEvent.Type, eventParams); }; query.BatchCompleted += batchCompleteCallback; @@ -453,9 +497,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution MessageParams eventParams = new MessageParams { Message = m, - OwnerUri = executeDocumentSelectionParams.OwnerUri + OwnerUri = ownerUri }; - await requestContext.SendEvent(MessageEvent.Type, eventParams); + await eventSender.SendEvent(MessageEvent.Type, eventParams); }; query.BatchMessageSent += batchMessageCallback; @@ -465,9 +509,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution ResultSetEventParams eventParams = new ResultSetEventParams { ResultSetSummary = r.Summary, - OwnerUri = executeDocumentSelectionParams.OwnerUri + OwnerUri = ownerUri }; - await requestContext.SendEvent(ResultSetCompleteEvent.Type, eventParams); + await eventSender.SendEvent(ResultSetCompleteEvent.Type, eventParams); }; query.ResultSetCompleted += resultCallback; diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs index 43c4115c..28a286c2 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs @@ -48,11 +48,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... And then I dispose of the query var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; var disposeRequest = new EventFlowValidator() - .AddResultValidation(r => - { - // Then: Messages should be null - Assert.Null(r.Messages); - }).Complete(); + .AddStandardQueryDisposeValidator() + .Complete(); await queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object); // Then: @@ -71,13 +68,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; var disposeRequest = new EventFlowValidator() - .AddResultValidation(r => - { - // Then: Messages should not be null - Assert.NotNull(r.Messages); - Assert.NotEmpty(r.Messages); - }).Complete(); + .AddErrorValidation(Assert.NotEmpty) + .Complete(); await queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object); + + // Then: I should have received an error disposeRequest.Validate(); } @@ -107,4 +102,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Assert.Empty(queryService.ActiveQueries); } } + + public static class QueryDisposeEventFlowValidatorExtensions + { + public static EventFlowValidator AddStandardQueryDisposeValidator( + this EventFlowValidator evf) + { + // We just need to make sure that the result is not null + evf.AddResultValidation(Assert.NotNull); + + return evf; + } + } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs index 10c22bf5..c6b5721b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Execution/ServiceIntegrationTests.cs @@ -88,6 +88,97 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution #endregion + #region Inter-Service API Tests + + [Fact] + public async Task InterServiceExecuteNullExecuteParams() + { + // Setup: Create a query service + var qes = new QueryExecutionService(null, null); + var eventSender = new EventFlowValidator().Complete().Object; + Func successFunc = () => Task.FromResult(0); + Func errorFunc = Task.FromResult; + + + // If: I call the inter-service API to execute with a null execute params + // Then: It should throw + await Assert.ThrowsAsync( + () => qes.InterServiceExecuteQuery(null, eventSender, successFunc, errorFunc)); + } + + [Fact] + public async Task InterServiceExecuteNullEventSender() + { + // Setup: Create a query service, and execute params + var qes = new QueryExecutionService(null, null); + var executeParams = new ExecuteStringParams(); + Func successFunc = () => Task.FromResult(0); + Func errorFunc = Task.FromResult; + + // If: I call the inter-service API to execute a query with a a null event sender + // Then: It should throw + await Assert.ThrowsAsync( + () => qes.InterServiceExecuteQuery(executeParams, null, successFunc, errorFunc)); + } + + [Fact] + public async Task InterServiceExecuteNullSuccessFunc() + { + // Setup: Create a query service, and execute params + var qes = new QueryExecutionService(null, null); + var executeParams = new ExecuteStringParams(); + var eventSender = new EventFlowValidator().Complete().Object; + Func errorFunc = Task.FromResult; + + // If: I call the inter-service API to execute a query with a a null success function + // Then: It should throw + await Assert.ThrowsAsync( + () => qes.InterServiceExecuteQuery(executeParams, eventSender, null, errorFunc)); + } + + [Fact] + public async Task InterServiceExecuteNullFailureFunc() + { + // Setup: Create a query service, and execute params + var qes = new QueryExecutionService(null, null); + var executeParams = new ExecuteStringParams(); + var eventSender = new EventFlowValidator().Complete().Object; + Func successFunc = () => Task.FromResult(0); + + // If: I call the inter-service API to execute a query with a a null failure function + // Then: It should throw + await Assert.ThrowsAsync( + () => qes.InterServiceExecuteQuery(executeParams, eventSender, successFunc, null)); + } + + [Fact] + public async Task InterServiceDisposeNullSuccessFunc() + { + // Setup: Create a query service and dispose params + var qes = new QueryExecutionService(null, null); + Func failureFunc = Task.FromResult; + + // If: I call the inter-service API to dispose a query with a null success function + // Then: It should throw + await Assert.ThrowsAsync( + () => qes.InterServiceDisposeQuery(Common.OwnerUri, null, failureFunc)); + } + + [Fact] + public async Task InterServiceDisposeNullFailureFunc() + { + // Setup: Create a query service and dispose params + var qes = new QueryExecutionService(null, null); + Func successFunc = () => Task.FromResult(0); + + // If: I call the inter-service API to dispose a query with a null success function + // Then: It should throw + await Assert.ThrowsAsync( + () => qes.InterServiceDisposeQuery(Common.OwnerUri, successFunc, null)); + } + + #endregion + #region Execution Tests // NOTE: In order to limit test duplication, we're running the ExecuteDocumentSelection // version of execute query. The code paths are almost identical. @@ -378,7 +469,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.Execution } } - public static class EventFlowValidatorExtensions + public static class QueryExecutionEventFlowValidatorExtensions { public static EventFlowValidator AddStandardQueryResultValidator( this EventFlowValidator efv)