//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Channels;
using Microsoft.SqlTools.Hosting.Contracts;
using Microsoft.SqlTools.Hosting.Contracts.Internal;
using Microsoft.SqlTools.Hosting.Utility;
using Microsoft.SqlTools.Hosting.v2;

namespace Microsoft.SqlTools.Hosting.Protocol
{
    /// <summary>
    /// JSON RPC host that reads messages from a <see cref="ChannelBase"/>, dispatches them to the
    /// registered request/event handlers, and writes queued outgoing messages back to the channel.
    /// </summary>
    public class JsonRpcHost : IJsonRpcHost
    {
        #region Private Fields

        internal readonly CancellationTokenSource cancellationTokenSource;
        private readonly CancellationToken consumeInputCancellationToken;
        private readonly CancellationToken consumeOutputCancellationToken;

        // Messages waiting to be written to the channel by the output loop
        internal readonly BlockingCollection<Message> outputQueue;

        // Handlers keyed by JSON RPC method name
        internal readonly Dictionary<string, Func<Message, Task>> eventHandlers;
        internal readonly Dictionary<string, Func<Message, Task>> requestHandlers;

        // Outbound requests keyed by message id; completed when a response with that id is dispatched
        internal readonly ConcurrentDictionary<string, TaskCompletionSource<Message>> pendingRequests;
        internal readonly ChannelBase protocolChannel;

        internal Task consumeInputTask;
        internal Task consumeOutputTask;
        private bool isStarted;

        #endregion

        /// <summary>
        /// Creates a JSON RPC host that communicates over the provided protocol channel.
        /// </summary>
        /// <param name="channel">Channel used to read incoming and write outgoing messages</param>
        public JsonRpcHost(ChannelBase channel)
        {
            Validate.IsNotNull(nameof(channel), channel);

            cancellationTokenSource = new CancellationTokenSource();
            consumeInputCancellationToken = cancellationTokenSource.Token;
            consumeOutputCancellationToken = cancellationTokenSource.Token;
            outputQueue = new BlockingCollection<Message>(new ConcurrentQueue<Message>());
            protocolChannel = channel;

            eventHandlers = new Dictionary<string, Func<Message, Task>>();
            requestHandlers = new Dictionary<string, Func<Message, Task>>();
            pendingRequests = new ConcurrentDictionary<string, TaskCompletionSource<Message>>();
        }

        #region Start/Stop Methods

        /// <summary>
        /// Starts the JSON RPC host using the protocol channel that was provided
        /// </summary>
        /// <exception cref="InvalidOperationException">Thrown if the host has already been started</exception>
        public void Start()
        {
            // If we've already started, we can't start up again
            if (isStarted)
            {
                throw new InvalidOperationException(SR.HostingJsonRpcHostAlreadyStarted);
            }

            // Make sure no other calls try to start the endpoint during startup
            isStarted = true;

            // Initialize the protocol channel and block until a connection is established
            protocolChannel.Start();
            protocolChannel.WaitForConnection().Wait();

            // Start the input and output consumption threads
            consumeInputTask = ConsumeInput();
            consumeOutputTask = ConsumeOutput();
        }

        /// <summary>
        /// Stops the JSON RPC host and the underlying protocol channel. Any requests sent via
        /// <see cref="SendRequest{TParams,TResult}"/> that are still awaiting a response are cancelled.
        /// </summary>
        /// <exception cref="InvalidOperationException">Thrown if the host has not been started</exception>
        public void Stop()
        {
            // If we haven't started, we can't stop
            if (!isStarted)
            {
                throw new InvalidOperationException(SR.HostingJsonRpcHostNotStarted);
            }

            // Make sure no future calls try to stop the endpoint during shutdown
            isStarted = false;

            // Shutdown the host
            cancellationTokenSource.Cancel();
            protocolChannel.Stop();

            // Cancel outstanding requests so their callers are not left awaiting a response forever.
            // TryRemove guarantees each pending request is completed exactly once, even if a response
            // is being dispatched concurrently.
            foreach (string pendingId in pendingRequests.Keys)
            {
                TaskCompletionSource<Message> pendingRequest;
                if (pendingRequests.TryRemove(pendingId, out pendingRequest))
                {
                    pendingRequest.TrySetCanceled();
                }
            }
        }

        /// <summary>
        /// Waits for input and output threads to naturally exit
        /// </summary>
        /// <exception cref="InvalidOperationException">Thrown if the host has not started</exception>
        public void WaitForExit()
        {
            // If we haven't started everything, we can't wait for exit
            if (!isStarted)
            {
                throw new InvalidOperationException(SR.HostingJsonRpcHostNotStarted);
            }

            // Join the input and output threads to this thread
            Task.WaitAll(consumeInputTask, consumeOutputTask);
        }

        #endregion

        #region Public Methods

        /// <summary>
        /// Sends an event, independent of any request. The event is queued and written to the
        /// channel by the output loop; this method does not wait for the write to happen.
        /// </summary>
        /// <typeparam name="TParams">Event parameter type</typeparam>
        /// <param name="eventType">Type of event being sent</param>
        /// <param name="eventParams">Event parameters being sent</param>
        /// <exception cref="InvalidOperationException">Thrown if the protocol channel is not connected</exception>
        public void SendEvent<TParams>(
            EventType<TParams> eventType,
            TParams eventParams)
        {
            if (!protocolChannel.IsConnected)
            {
                throw new InvalidOperationException("SendEvent called when ProtocolChannel was not yet connected");
            }

            // Create a message from the event provided
            Message message = Message.CreateEvent(eventType, eventParams);
            outputQueue.Add(message);
        }

        /// <summary>
        /// Sends a request to the other end of the channel and waits for the matching response
        /// </summary>
        /// <param name="requestType">Configuration of the request that is being sent</param>
        /// <param name="requestParams">Contents of the request</param>
        /// <typeparam name="TParams">Type of the message contents</typeparam>
        /// <typeparam name="TResult">Type of the contents of the expected result of the request</typeparam>
        /// <returns>Task that completes with the typed contents of the response once it is received</returns>
        /// <exception cref="InvalidOperationException">Thrown if the protocol channel is not connected</exception>
        /// <exception cref="TaskCanceledException">Thrown if the host is stopped before a response arrives</exception>
        /// TODO: This doesn't properly handle error responses scenarios.
        public async Task<TResult> SendRequest<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            TParams requestParams)
        {
            if (!protocolChannel.IsConnected)
            {
                throw new InvalidOperationException("SendRequest called when ProtocolChannel was not yet connected");
            }

            // Add a task completion source for the request's response. Continuations must run
            // asynchronously: the response is completed from the input loop (DispatchMessage), and
            // running the caller's continuation inline there would stall input processing.
            string messageId = Guid.NewGuid().ToString();
            TaskCompletionSource<Message> responseTask =
                new TaskCompletionSource<Message>(TaskCreationOptions.RunContinuationsAsynchronously);
            pendingRequests.TryAdd(messageId, responseTask);

            // Send the request
            outputQueue.Add(Message.CreateRequest(requestType, messageId, requestParams));

            // Wait for the response
            Message responseMessage = await responseTask.Task;

            return responseMessage.GetTypedContents<TResult>();
        }

        /// <summary>
        /// Sets the handler for an event with a given configuration
        /// </summary>
        /// <param name="eventType">Configuration of the event</param>
        /// <param name="eventHandler">Function for handling the event</param>
        /// <param name="overrideExisting">Whether or not to override any existing event handler for this method</param>
        /// <typeparam name="TParams">Type of the parameters for the event</typeparam>
        public void SetAsyncEventHandler<TParams>(
            EventType<TParams> eventType,
            Func<TParams, EventContext, Task> eventHandler,
            bool overrideExisting = false)
        {
            Validate.IsNotNull(nameof(eventType), eventType);
            Validate.IsNotNull(nameof(eventHandler), eventHandler);

            if (overrideExisting)
            {
                // Remove the existing handler so a new one can be set
                eventHandlers.Remove(eventType.MethodName);
            }

            // Wrap the typed handler so it can be invoked with the raw message during dispatch
            Func<Message, Task> handler = eventMessage =>
                eventHandler(eventMessage.GetTypedContents<TParams>(), new EventContext(outputQueue));

            // Dictionary.Add throws if a handler is already registered and overrideExisting is false
            eventHandlers.Add(eventType.MethodName, handler);
        }

        /// <summary>
        /// Sets the handler for an event using a synchronous action. The action is wrapped in a
        /// <see cref="Task.Run(Action)"/> call and registered via the Func-based overload.
        /// </summary>
        /// <param name="eventType">Configuration of the event</param>
        /// <param name="eventHandler">Action for handling the event</param>
        /// <param name="overrideExisting">Whether or not to override any existing event handler for this method</param>
        /// <typeparam name="TParams">Type of the parameters for the event</typeparam>
        public void SetEventHandler<TParams>(
            EventType<TParams> eventType,
            Action<TParams, EventContext> eventHandler,
            bool overrideExisting = false)
        {
            Validate.IsNotNull(nameof(eventHandler), eventHandler);
            Func<TParams, EventContext, Task> eventFunc = (p, e) => Task.Run(() => eventHandler(p, e));
            SetAsyncEventHandler(eventType, eventFunc, overrideExisting);
        }

        /// <summary>
        /// Sets the handler for a request with a given configuration
        /// </summary>
        /// <param name="requestType">Configuration of the request</param>
        /// <param name="requestHandler">Function for handling the request</param>
        /// <param name="overrideExisting">Whether or not to override any existing request handler for this method</param>
        /// <typeparam name="TParams">Type of the parameters for the request</typeparam>
        /// <typeparam name="TResult">Type of the parameters for the response</typeparam>
        public void SetAsyncRequestHandler<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            Func<TParams, RequestContext<TResult>, Task> requestHandler,
            bool overrideExisting = false)
        {
            Validate.IsNotNull(nameof(requestType), requestType);
            Validate.IsNotNull(nameof(requestHandler), requestHandler);

            if (overrideExisting)
            {
                // Remove the existing handler so a new one can be set
                requestHandlers.Remove(requestType.MethodName);
            }

            // Setup the wrapper around the handler
            Func<Message, Task> handler = requestMessage =>
                requestHandler(requestMessage.GetTypedContents<TParams>(), new RequestContext<TResult>(requestMessage, outputQueue));

            // Dictionary.Add throws if a handler is already registered and overrideExisting is false
            requestHandlers.Add(requestType.MethodName, handler);
        }

        /// <summary>
        /// Sets the handler for a request using a synchronous action. The action is wrapped in a
        /// <see cref="Task.Run(Action)"/> call and registered via the Func-based overload.
        /// </summary>
        /// <param name="requestType">Configuration of the request</param>
        /// <param name="requestHandler">Action for handling the request</param>
        /// <param name="overrideExisting">Whether or not to override any existing request handler for this method</param>
        /// <typeparam name="TParams">Type of the parameters for the request</typeparam>
        /// <typeparam name="TResult">Type of the parameters for the response</typeparam>
        public void SetRequestHandler<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            Action<TParams, RequestContext<TResult>> requestHandler,
            bool overrideExisting = false)
        {
            Validate.IsNotNull(nameof(requestHandler), requestHandler);
            Func<TParams, RequestContext<TResult>, Task> requestFunc = (p, e) => Task.Run(() => requestHandler(p, e));
            SetAsyncRequestHandler(requestType, requestFunc, overrideExisting);
        }

        #endregion

        #region Message Processing Tasks

        /// <summary>
        /// Starts the loop that reads messages from the channel and dispatches them to handlers.
        /// The loop exits when the input stream ends or the host's cancellation token is signalled
        /// (checked between messages).
        /// </summary>
        /// <returns>Task that completes when the input loop exits</returns>
        internal Task ConsumeInput()
        {
            return Task.Factory.StartNew(async () =>
            {
                while (!consumeInputCancellationToken.IsCancellationRequested)
                {
                    Message incomingMessage;
                    try
                    {
                        // Read message from the input channel
                        incomingMessage = await protocolChannel.MessageReader.ReadMessage();
                    }
                    catch (EndOfStreamException)
                    {
                        // The stream has ended, end the input message loop
                        break;
                    }
                    catch (Exception e)
                    {
                        // Log the error; no error event is sent to the client yet
                        string message = $"Exception occurred while receiving input message: {e.Message}";
                        Logger.Write(TraceEventType.Error, message);

                        // TODO: Add event to output queue, and unit test it

                        // Continue the loop
                        continue;
                    }

                    // Verbose logging
                    string logMessage = $"Received message with Id[{incomingMessage.Id}] of type[{incomingMessage.MessageType}] and method[{incomingMessage.Method}]";
                    Logger.Write(TraceEventType.Verbose, logMessage);

                    // Process the message
                    try
                    {
                        await DispatchMessage(incomingMessage);
                    }
                    catch (MethodHandlerDoesNotExistException)
                    {
                        // Method could not be handled, if the message was a request, send an error back to the client
                        // TODO: Localize
                        string mnfLogMessage = $"Failed to find method handler for type[{incomingMessage.MessageType}] and method[{incomingMessage.Method}]";
                        Logger.Write(TraceEventType.Warning, mnfLogMessage);

                        if (incomingMessage.MessageType == MessageType.Request)
                        {
                            // -32601 is the JSON RPC "Method not found" error code
                            // TODO: Localize
                            Error mnfError = new Error { Code = -32601, Message = "Method not found" };
                            Message errorMessage = Message.CreateResponseError(incomingMessage.Id, mnfError);
                            outputQueue.Add(errorMessage, consumeInputCancellationToken);
                        }
                    }
                    catch (Exception e)
                    {
                        // General errors should be logged but not halt the processing loop
                        string geLogMessage = $"Exception thrown when handling message of type[{incomingMessage.MessageType}] and method[{incomingMessage.Method}]: {e}";
                        Logger.Write(TraceEventType.Error, geLogMessage);
                        // TODO: Should we be returning a response for failing requests?
                    }
                }

                Logger.Write(TraceEventType.Warning, "Exiting consume input loop!");
            }, consumeInputCancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default).Unwrap();
        }

        /// <summary>
        /// Starts the loop that takes messages from the output queue and writes them to the channel.
        /// The loop exits when the host's cancellation token is signalled or taking from the queue fails.
        /// </summary>
        /// <returns>Task that completes when the output loop exits</returns>
        internal Task ConsumeOutput()
        {
            return Task.Factory.StartNew(async () =>
            {
                while (!consumeOutputCancellationToken.IsCancellationRequested)
                {
                    Message outgoingMessage;
                    try
                    {
                        // Read message from the output queue (blocks until a message is available)
                        outgoingMessage = outputQueue.Take(consumeOutputCancellationToken);
                    }
                    catch (OperationCanceledException)
                    {
                        // Cancelled during taking, end the loop
                        break;
                    }
                    catch (Exception e)
                    {
                        // If we hit an exception here, it is unrecoverable
                        string message = $"Unexpected exception occurred while receiving output message: {e.Message}";
                        Logger.Write(TraceEventType.Error, message);

                        break;
                    }

                    // Send the message
                    string logMessage = $"Sending message of type[{outgoingMessage.MessageType}] and method[{outgoingMessage.Method}]";
                    Logger.Write(TraceEventType.Verbose, logMessage);

                    // NOTE(review): an exception thrown by WriteMessage is not caught here and will fault the output loop task
                    await protocolChannel.MessageWriter.WriteMessage(outgoingMessage);
                }

                Logger.Write(TraceEventType.Warning, "Exiting consume output loop!");
            }, consumeOutputCancellationToken, TaskCreationOptions.LongRunning, TaskScheduler.Default).Unwrap();
        }

        /// <summary>
        /// Routes a single incoming message: requests and events go to their registered handler,
        /// responses complete the matching pending request.
        /// </summary>
        /// <param name="messageToDispatch">Message read from the input channel</param>
        /// <returns>Task that completes when the handler (if any) has finished</returns>
        /// <exception cref="MethodHandlerDoesNotExistException">
        /// Thrown if no handler is registered for a request/event method, or no pending request matches a response id
        /// </exception>
        internal async Task DispatchMessage(Message messageToDispatch)
        {
            Task handlerToAwait = null;

            switch (messageToDispatch.MessageType)
            {
                case MessageType.Request:
                    Func<Message, Task> requestHandler;
                    if (requestHandlers.TryGetValue(messageToDispatch.Method, out requestHandler))
                    {
                        handlerToAwait = requestHandler(messageToDispatch);
                    }
                    else
                    {
                        throw new MethodHandlerDoesNotExistException(MessageType.Request, messageToDispatch.Method);
                    }
                    break;
                case MessageType.Response:
                    TaskCompletionSource<Message> requestTask;
                    if (pendingRequests.TryRemove(messageToDispatch.Id, out requestTask))
                    {
                        requestTask.SetResult(messageToDispatch);
                        return;
                    }
                    else
                    {
                        throw new MethodHandlerDoesNotExistException(MessageType.Response, "response");
                    }
                case MessageType.Event:
                    Func<Message, Task> eventHandler;
                    if (eventHandlers.TryGetValue(messageToDispatch.Method, out eventHandler))
                    {
                        handlerToAwait = eventHandler(messageToDispatch);
                    }
                    else
                    {
                        throw new MethodHandlerDoesNotExistException(MessageType.Event, messageToDispatch.Method);
                    }
                    break;
                default:
                    // TODO: This case isn't handled properly
                    break;
            }

            // Skip processing if there isn't anything to do
            if (handlerToAwait == null)
            {
                return;
            }

            // Run the handler
            try
            {
                await handlerToAwait;
            }
            catch (TaskCanceledException)
            {
                // Some tasks may be cancelled due to legitimate
                // timeouts so don't let those exceptions go higher.
            }
            catch (AggregateException e) when (e.InnerExceptions.Count > 0 && e.InnerExceptions[0] is TaskCanceledException)
            {
                // Cancelled tasks aren't a problem; any other AggregateException (including one
                // with no inner exceptions) is not caught here and propagates to the caller.
            }
        }

        #endregion
    }
}