//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Contracts;
using Microsoft.SqlTools.Hosting.Protocol.Channel;
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.Utility;

namespace Microsoft.SqlTools.Hosting.Protocol
{
    /// <summary>
    /// Reads messages from a protocol channel on a dedicated message-loop thread and routes
    /// each one to the request, event, or response handler registered for it.
    /// </summary>
    public class MessageDispatcher
    {
        #region Fields

        private readonly ChannelBase protocolChannel;

        private AsyncContextThread messageLoopThread;

        // Request handlers keyed by the method name of the request type they were registered for.
        internal Dictionary<string, Func<Message, MessageWriter, Task>> requestHandlers =
            new Dictionary<string, Func<Message, MessageWriter, Task>>();

        // Per request method: whether its handler may run off the message loop thread.
        internal Dictionary<string, bool> requestHandlerParallelismMap = new Dictionary<string, bool>();

        // Event handlers keyed by the method name of the event type they were registered for.
        internal Dictionary<string, Func<Message, MessageWriter, Task>> eventHandlers =
            new Dictionary<string, Func<Message, MessageWriter, Task>>();

        // Per event method: whether its handler may run off the message loop thread.
        internal Dictionary<string, bool> eventHandlerParallelismMap = new Dictionary<string, bool>();

        private Action<Message> responseHandler;

        private readonly CancellationTokenSource messageLoopCancellationToken = new CancellationTokenSource();

        // Limit to 10 threads to begin with, ideally there shouldn't be any limitation
        private readonly SemaphoreSlim semaphore = new SemaphoreSlim(10);

        #endregion

        #region Properties

        /// <summary>
        /// The synchronization context captured when the message loop began running.
        /// </summary>
        public SynchronizationContext SynchronizationContext { get; private set; }

        /// <summary>
        /// True when the current synchronization context is the one captured by the message loop.
        /// </summary>
        public bool InMessageLoopThread
        {
            get
            {
                // We're in the same thread as the message loop if the
                // current synchronization context equals the one we
                // know.
                return SynchronizationContext.Current == this.SynchronizationContext;
            }
        }

        protected MessageReader MessageReader { get; private set; }

        protected MessageWriter MessageWriter { get; private set; }

        /// <summary>
        /// Whether the message should be handled without blocking the main thread.
        /// Only handlers registered with <c>isParallelProcessingSupported</c> are affected.
        /// </summary>
        public bool ParallelMessageProcessing { get; set; }

        #endregion

        #region Constructors

        /// <summary>
        /// Creates a dispatcher that reads from and writes to the given channel.
        /// </summary>
        /// <param name="protocolChannel">Channel supplying the message reader and writer.</param>
        public MessageDispatcher(ChannelBase protocolChannel)
        {
            this.protocolChannel = protocolChannel;
            this.MessageReader = protocolChannel.MessageReader;
            this.MessageWriter = protocolChannel.MessageWriter;
        }

        #endregion

        #region Public Methods

        /// <summary>
        /// Starts the message loop on a new <see cref="AsyncContextThread"/>.
        /// </summary>
        public void Start()
        {
            // Start the main message loop thread.  The Task is
            // not explicitly awaited because it is running on
            // an independent background thread.
            this.messageLoopThread = new AsyncContextThread("Message Dispatcher");
            this.messageLoopThread
                .Run(() => this.ListenForMessages(this.messageLoopCancellationToken.Token))
                .ContinueWith(this.OnListenTaskCompleted);
        }

        /// <summary>
        /// Requests cancellation of the message loop and stops its thread.
        /// Does nothing if <see cref="Start"/> has not been called.
        /// </summary>
        public void Stop()
        {
            // Stop the message loop thread
            if (this.messageLoopThread != null)
            {
                this.messageLoopCancellationToken.Cancel();
                this.messageLoopThread.Stop();
            }
        }

        /// <summary>
        /// Registers a handler for a request type. Throws if a handler for the same
        /// method is already registered.
        /// </summary>
        public void SetRequestHandler<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            Func<TParams, RequestContext<TResult>, Task> requestHandler)
        {
            this.SetRequestHandler(
                requestType,
                requestHandler,
                false);
        }

        /// <summary>
        /// Registers a handler for a request type.
        /// </summary>
        /// <param name="requestType">Request type; its method name is the registration key.</param>
        /// <param name="requestHandler">Handler invoked with the deserialized parameters and a reply context.</param>
        /// <param name="overrideExisting">
        /// When true, replaces any handler (and parallelism setting) already registered for the method.
        /// When false, a duplicate registration throws <see cref="ArgumentException"/>.
        /// </param>
        /// <param name="isParallelProcessingSupported">
        /// Whether the handler may run off the message loop when <see cref="ParallelMessageProcessing"/> is set.
        /// </param>
        public void SetRequestHandler<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            Func<TParams, RequestContext<TResult>, Task> requestHandler,
            bool overrideExisting,
            bool isParallelProcessingSupported = false)
        {
            if (overrideExisting)
            {
                // Remove the existing handler and its parallelism flag so new ones can be set.
                // Leaving the flag behind would make the Add below throw for an existing method.
                this.requestHandlers.Remove(requestType.MethodName);
                this.requestHandlerParallelismMap.Remove(requestType.MethodName);
            }

            this.requestHandlerParallelismMap.Add(requestType.MethodName, isParallelProcessingSupported);
            this.requestHandlers.Add(
                requestType.MethodName,
                async (requestMessage, messageWriter) =>
                {
                    Logger.Verbose($"Processing message with id[{requestMessage.Id}], of type[{requestMessage.MessageType}] and method[{requestMessage.Method}]");
                    var requestContext =
                        new RequestContext<TResult>(
                            requestMessage,
                            messageWriter);
                    try
                    {
                        TParams typedParams = default(TParams);
                        if (requestMessage.Contents != null)
                        {
                            try
                            {
                                typedParams = requestMessage.Contents.ToObject<TParams>();
                            }
                            catch (Exception ex)
                            {
                                throw new Exception($"Error parsing message contents {requestMessage.Contents}", ex);
                            }
                        }

                        await requestHandler(typedParams, requestContext);
                        Logger.Verbose($"Finished processing message with id[{requestMessage.Id}], of type[{requestMessage.MessageType}] and method[{requestMessage.Method}]");
                    }
                    catch (Exception ex)
                    {
                        // Report the failure back to the client as an error response.
                        Logger.Error($"{requestType.MethodName} : {ex.GetFullErrorMessage(true)}");
                        await requestContext.SendError(ex.GetFullErrorMessage());
                    }
                });
        }

        /// <summary>
        /// Registers a handler for an event type. Throws if a handler for the same
        /// method is already registered.
        /// </summary>
        public void SetEventHandler<TParams>(
            EventType<TParams> eventType,
            Func<TParams, EventContext, Task> eventHandler)
        {
            this.SetEventHandler(
                eventType,
                eventHandler,
                false);
        }

        /// <summary>
        /// Registers a handler for an event type.
        /// </summary>
        /// <param name="eventType">Event type; its method name is the registration key.</param>
        /// <param name="eventHandler">Handler invoked with the deserialized parameters.</param>
        /// <param name="overrideExisting">
        /// When true, replaces any handler (and parallelism setting) already registered for the method.
        /// When false, a duplicate registration throws <see cref="ArgumentException"/>.
        /// </param>
        /// <param name="isParallelProcessingSupported">
        /// Whether the handler may run off the message loop when <see cref="ParallelMessageProcessing"/> is set.
        /// </param>
        public void SetEventHandler<TParams>(
            EventType<TParams> eventType,
            Func<TParams, EventContext, Task> eventHandler,
            bool overrideExisting,
            bool isParallelProcessingSupported = false)
        {
            if (overrideExisting)
            {
                // Remove the existing handler and its parallelism flag so new ones can be set.
                // Leaving the flag behind would make the Add below throw for an existing method.
                this.eventHandlers.Remove(eventType.MethodName);
                this.eventHandlerParallelismMap.Remove(eventType.MethodName);
            }

            this.eventHandlerParallelismMap.Add(eventType.MethodName, isParallelProcessingSupported);
            this.eventHandlers.Add(
                eventType.MethodName,
                async (eventMessage, messageWriter) =>
                {
                    Logger.Verbose($"Processing message with id[{eventMessage.Id}], of type[{eventMessage.MessageType}] and method[{eventMessage.Method}]");
                    var eventContext = new EventContext(messageWriter);
                    TParams typedParams = default(TParams);
                    try
                    {
                        if (eventMessage.Contents != null)
                        {
                            try
                            {
                                typedParams = eventMessage.Contents.ToObject<TParams>();
                            }
                            catch (Exception ex)
                            {
                                throw new Exception($"Error parsing message contents {eventMessage.Contents}", ex);
                            }
                        }

                        await eventHandler(typedParams, eventContext);
                        Logger.Verbose($"Finished processing message with id[{eventMessage.Id}], of type[{eventMessage.MessageType}] and method[{eventMessage.Method}]");
                    }
                    catch (Exception ex)
                    {
                        // There's nothing on the client side to send an error back to so just log the error and move on
                        Logger.Error($"{eventType.MethodName} : {ex}");
                    }
                });
        }

        /// <summary>
        /// Sets the callback that receives every message of type Response.
        /// Replaces any previously set callback.
        /// </summary>
        public void SetResponseHandler(Action<Message> responseHandler)
        {
            this.responseHandler = responseHandler;
        }

        #endregion

        #region Events

        /// <summary>
        /// Raised when the message loop task faults.
        /// </summary>
        public event EventHandler<Exception> UnhandledException;

        protected void OnUnhandledException(Exception unhandledException)
        {
            // ?.Invoke reads the delegate once, avoiding a race between the null check and the call.
            this.UnhandledException?.Invoke(this, unhandledException);
        }

        #endregion

        #region Private Methods

        /// <summary>
        /// Reads and dispatches messages until cancellation is requested or the stream ends.
        /// Read failures other than end-of-stream are logged, reported to the client as a
        /// HostingErrorEvent, and the loop continues.
        /// </summary>
        private async Task ListenForMessages(CancellationToken cancellationToken)
        {
            this.SynchronizationContext = SynchronizationContext.Current;

            // Run the message loop
            while (!cancellationToken.IsCancellationRequested)
            {
                Message newMessage;

                try
                {
                    // Read a message from the channel
                    newMessage = await this.MessageReader.ReadMessage();
                }
                catch (MessageParseException e)
                {
                    string message = string.Format("Exception occurred while parsing message: {0}", e.Message);
                    Logger.Error(message);
                    await this.MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams { Message = message });

                    // Continue the loop
                    continue;
                }
                catch (EndOfStreamException)
                {
                    // The stream has ended, end the message loop
                    break;
                }
                catch (Exception e)
                {
                    // Log the error and send an error event to the client
                    string message = string.Format("Exception occurred while receiving message: {0}", e.Message);
                    Logger.Error(message);
                    await this.MessageWriter.WriteEvent(HostingErrorEvent.Type, new HostingErrorParams { Message = message });

                    // Continue the loop
                    continue;
                }

                // The message could be null if there was an error parsing the
                // previous message.  In this case, do not try to dispatch it.
                if (newMessage != null)
                {
                    // Verbose logging
                    string logMessage = $"Received message with id[{newMessage.Id}], of type[{newMessage.MessageType}] and method[{newMessage.Method}]";
                    Logger.Verbose(logMessage);

                    // Process the message
                    await this.DispatchMessage(newMessage, this.MessageWriter);
                }
            }
        }

        /// <summary>
        /// Routes one message to its handler. Responses go to the handler set by
        /// <see cref="SetResponseHandler"/>. For requests and events, when
        /// <see cref="ParallelMessageProcessing"/> is enabled and the handler was registered as
        /// parallel-capable, the handler is started on the thread pool after waiting for one of
        /// the 10 semaphore slots, and this method does not wait for it to finish. Otherwise the
        /// handler is awaited inline. Handler exceptions are logged, not rethrown.
        /// </summary>
        protected async Task DispatchMessage(
            Message messageToDispatch,
            MessageWriter messageWriter)
        {
            Func<Message, MessageWriter, Task> handlerToAwait = null;
            bool isParallelProcessingSupported = false;

            if (messageToDispatch.MessageType == MessageType.Request)
            {
                this.requestHandlers.TryGetValue(messageToDispatch.Method, out handlerToAwait);
                this.requestHandlerParallelismMap.TryGetValue(messageToDispatch.Method, out isParallelProcessingSupported);
            }
            else if (messageToDispatch.MessageType == MessageType.Response)
            {
                this.responseHandler?.Invoke(messageToDispatch);
            }
            else if (messageToDispatch.MessageType == MessageType.Event)
            {
                this.eventHandlers.TryGetValue(messageToDispatch.Method, out handlerToAwait);
                this.eventHandlerParallelismMap.TryGetValue(messageToDispatch.Method, out isParallelProcessingSupported);
            }

            // NOTE: other message types, and requests/events with no registered handler, are
            // currently dropped without a reply (TODO: return "message not supported").
            if (handlerToAwait == null)
            {
                return;
            }

            try
            {
                if (this.ParallelMessageProcessing && isParallelProcessingSupported)
                {
                    // Run the task in a separate thread so that the main
                    // thread is not blocked. Use semaphore to limit the degree of parallelism.
                    await this.semaphore.WaitAsync();
                    _ = Task.Run(() => this.RunParallelHandler(handlerToAwait, messageToDispatch, messageWriter));
                }
                else
                {
                    await handlerToAwait(messageToDispatch, messageWriter);
                }
            }
            catch (Exception e)
            {
                LogHandlerException(e);
            }
        }

        /// <summary>
        /// Runs a handler off the message loop and always returns its semaphore slot, even if the
        /// handler throws; otherwise a failing handler would permanently consume a slot.
        /// </summary>
        private async Task RunParallelHandler(
            Func<Message, MessageWriter, Task> handler,
            Message message,
            MessageWriter messageWriter)
        {
            try
            {
                await handler(message, messageWriter);
            }
            catch (Exception e)
            {
                // Nothing awaits this task, so log here or the exception would go unobserved.
                LogHandlerException(e);
            }
            finally
            {
                this.semaphore.Release();
            }
        }

        /// <summary>
        /// Logs an exception that escaped a message handler without rethrowing it, to prevent
        /// handler errors from crashing the service.
        /// </summary>
        private static void LogHandlerException(Exception e)
        {
            if (e is TaskCanceledException)
            {
                // Some tasks may be cancelled due to legitimate
                // timeouts so don't let those exceptions go higher.
                Logger.Verbose(string.Format("A TaskCanceledException occurred in the request handler: {0}", e.ToString()));
                return;
            }

            // Cancellations wrapped in an AggregateException are not logged. Check the count
            // first so an empty AggregateException cannot throw from inside this handler.
            bool isWrappedCancellation =
                e is AggregateException aggregate
                && aggregate.InnerExceptions.Count > 0
                && aggregate.InnerExceptions[0] is TaskCanceledException;

            if (!isWrappedCancellation)
            {
                // Log the error but don't rethrow it to prevent any errors in the handler from crashing the service
                Logger.Error(string.Format("An unexpected error occurred in the request handler: {0}", e.ToString()));
            }
        }

        internal void OnListenTaskCompleted(Task listenTask)
        {
            if (listenTask.IsFaulted)
            {
                this.OnUnhandledException(listenTask.Exception);
            }
            else if (listenTask.IsCompleted || listenTask.IsCanceled)
            {
                // TODO: Dispose of anything?
            }
        }

        #endregion
    }
}