//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol.Channel;
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.SqlTools.EditorServices.Protocol.MessageProtocol
{
    /// <summary>
    /// Provides behavior for a client or server endpoint that
    /// communicates using the specified protocol.
    /// </summary>
    public class ProtocolEndpoint : IMessageSender
    {
        private bool isStarted;
        private int currentMessageId;
        private ChannelBase protocolChannel;
        private MessageProtocolType messageProtocolType;
        private TaskCompletionSource<bool> endpointExitedTask;
        private SynchronizationContext originalSynchronizationContext;

        // Guards pendingRequests: entries are added by SendRequest on the caller's
        // thread and removed by HandleResponse when the dispatcher delivers a response.
        private readonly object pendingRequestsLock = new object();

        // Outstanding requests keyed by message id (as a string, matching Message.Id).
        private readonly Dictionary<string, TaskCompletionSource<Message>> pendingRequests =
            new Dictionary<string, TaskCompletionSource<Message>>();

        /// <summary>
        /// Gets the MessageDispatcher which allows registration of
        /// handlers for requests, responses, and events that are
        /// transmitted through the channel.
        /// </summary>
        protected MessageDispatcher MessageDispatcher { get; set; }

        /// <summary>
        /// Initializes an instance of the protocol server using the
        /// specified channel for communication.
        /// </summary>
        /// <param name="protocolChannel">
        /// The channel to use for communication with the connected endpoint.
        /// </param>
        /// <param name="messageProtocolType">
        /// The type of message protocol used by the endpoint.
        /// </param>
        public ProtocolEndpoint(
            ChannelBase protocolChannel,
            MessageProtocolType messageProtocolType)
        {
            this.protocolChannel = protocolChannel;
            this.messageProtocolType = messageProtocolType;

            // Remembered so unhandled dispatcher exceptions can be re-thrown on the
            // context that created the endpoint (see MessageDispatcher_UnhandledException).
            this.originalSynchronizationContext = SynchronizationContext.Current;
        }

        /// <summary>
        /// Starts the protocol channel and the message dispatcher, notifies the
        /// implementation through <see cref="OnStart"/>, and begins waiting for a
        /// connection in the background. Does nothing if the endpoint is already started.
        /// </summary>
        /// <returns>
        /// A Task that completes once the endpoint is started. It does not wait
        /// for the connection itself to be established.
        /// </returns>
        public async Task Start()
        {
            if (!this.isStarted)
            {
                // Start the provided protocol channel
                this.protocolChannel.Start(this.messageProtocolType);

                // Create the message dispatcher for the channel
                this.MessageDispatcher = new MessageDispatcher(this.protocolChannel);

                // Set the handler for any message responses that come back
                this.MessageDispatcher.SetResponseHandler(this.HandleResponse);

                // Listen for unhandled exceptions from the dispatcher
                this.MessageDispatcher.UnhandledException += MessageDispatcher_UnhandledException;

                // Notify implementation about endpoint start
                await this.OnStart();

                // Wait for connection and notify the implementor.
                // NOTE: This task is not meant to be awaited.
                Task waitTask =
                    this.protocolChannel
                        .WaitForConnection()
                        .ContinueWith(
                            async (t) =>
                            {
                                // Start the MessageDispatcher
                                this.MessageDispatcher.Start();
                                await this.OnConnect();
                            });

                // Endpoint is now started
                this.isStarted = true;
            }
        }

        /// <summary>
        /// Blocks the calling thread until <see cref="Stop"/> completes the exit task
        /// or the dispatcher reports an unhandled exception.
        /// </summary>
        public void WaitForExit()
        {
            this.endpointExitedTask = new TaskCompletionSource<bool>();
            this.endpointExitedTask.Task.Wait();
        }

        /// <summary>
        /// Stops the implementation, the dispatcher, and the channel, then releases
        /// any caller blocked in <see cref="WaitForExit"/>. Does nothing if the
        /// endpoint is not started.
        /// </summary>
        /// <returns>A Task that completes when shutdown has finished.</returns>
        public async Task Stop()
        {
            if (this.isStarted)
            {
                // Make sure no future calls try to stop the endpoint during shutdown
                this.isStarted = false;

                // Stop the implementation first
                await this.OnStop();

                // Stop the dispatcher and channel
                this.MessageDispatcher.Stop();
                this.protocolChannel.Stop();

                // Notify anyone waiting for exit. TrySetResult is used because the
                // task may already have been faulted by an unhandled dispatcher
                // exception, in which case SetResult would throw.
                TaskCompletionSource<bool> exitTask = this.endpointExitedTask;
                if (exitTask != null)
                {
                    exitTask.TrySetResult(true);
                }
            }
        }

        #region Message Sending

        /// <summary>
        /// Sends a request through the channel and waits for its response.
        /// </summary>
        /// <typeparam name="TParams">The request parameter type.</typeparam>
        /// <typeparam name="TResult">The response result type.</typeparam>
        /// <param name="requestType">The type of request being sent.</param>
        /// <param name="requestParams">The request parameters being sent.</param>
        /// <returns>A Task that yields the deserialized response contents.</returns>
        public Task<TResult> SendRequest<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            TParams requestParams)
        {
            return this.SendRequest(requestType, requestParams, true);
        }

        /// <summary>
        /// Sends a request through the channel, optionally waiting for its response.
        /// </summary>
        /// <typeparam name="TParams">The request parameter type.</typeparam>
        /// <typeparam name="TResult">The response result type.</typeparam>
        /// <param name="requestType">The type of request being sent.</param>
        /// <param name="requestParams">The request parameters being sent.</param>
        /// <param name="waitForResponse">
        /// If true, the returned Task completes when the matching response arrives;
        /// otherwise it completes as soon as the request has been written.
        /// </param>
        /// <returns>
        /// A Task that yields the deserialized response contents, or
        /// default(TResult) when no response is awaited or the response has no contents.
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// The protocol channel is not connected.
        /// </exception>
        public async Task<TResult> SendRequest<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            TParams requestParams,
            bool waitForResponse)
        {
            if (!this.protocolChannel.IsConnected)
            {
                throw new InvalidOperationException("SendRequest called when ProtocolChannel was not yet connected");
            }

            // Take the id once, atomically, so concurrent callers can neither share
            // an id nor observe each other's increment between registration and write.
            int messageId = Interlocked.Increment(ref this.currentMessageId);
            string requestKey = messageId.ToString();

            TaskCompletionSource<Message> responseTask = null;
            if (waitForResponse)
            {
                responseTask = new TaskCompletionSource<Message>();

                lock (this.pendingRequestsLock)
                {
                    this.pendingRequests.Add(requestKey, responseTask);
                }
            }

            try
            {
                await this.protocolChannel.MessageWriter.WriteRequest(
                    requestType,
                    requestParams,
                    messageId);
            }
            catch
            {
                // The request never went out, so no response can arrive for it;
                // drop the registration instead of leaking it.
                if (responseTask != null)
                {
                    lock (this.pendingRequestsLock)
                    {
                        this.pendingRequests.Remove(requestKey);
                    }
                }

                throw;
            }

            if (responseTask != null)
            {
                var responseMessage = await responseTask.Task;

                return
                    responseMessage.Contents != null ?
                        responseMessage.Contents.ToObject<TResult>() :
                        default(TResult);
            }
            else
            {
                // TODO: Better default value here?
                return default(TResult);
            }
        }

        /// <summary>
        /// Sends an event to the channel's endpoint.
        /// </summary>
        /// <typeparam name="TParams">The event parameter type.</typeparam>
        /// <param name="eventType">The type of event being sent.</param>
        /// <param name="eventParams">The event parameters being sent.</param>
        /// <returns>A Task that tracks completion of the send operation.</returns>
        /// <exception cref="InvalidOperationException">
        /// The protocol channel is not connected.
        /// </exception>
        public Task SendEvent<TParams>(
            EventType<TParams> eventType,
            TParams eventParams)
        {
            if (!this.protocolChannel.IsConnected)
            {
                throw new InvalidOperationException("SendEvent called when ProtocolChannel was not yet connected");
            }

            // Some events could be raised from a different thread.
            // To ensure that messages are written serially, dispatch
            // the SendEvent call to the message loop thread.
            if (!this.MessageDispatcher.InMessageLoopThread)
            {
                TaskCompletionSource<bool> writeTask = new TaskCompletionSource<bool>();

                this.MessageDispatcher.SynchronizationContext.Post(
                    async (obj) =>
                    {
                        // This lambda is effectively async void: an exception that
                        // escapes would be thrown on the dispatcher's context and
                        // writeTask would never complete. Route failures to the
                        // caller's task instead.
                        try
                        {
                            await this.protocolChannel.MessageWriter.WriteEvent(
                                eventType,
                                eventParams);

                            writeTask.TrySetResult(true);
                        }
                        catch (Exception e)
                        {
                            writeTask.TrySetException(e);
                        }
                    }, null);

                return writeTask.Task;
            }
            else
            {
                return this.protocolChannel.MessageWriter.WriteEvent(
                    eventType,
                    eventParams);
            }
        }

        #endregion

        #region Message Handling

        /// <summary>
        /// Registers a handler for the given request type with the MessageDispatcher.
        /// </summary>
        /// <typeparam name="TParams">The request parameter type.</typeparam>
        /// <typeparam name="TResult">The response result type.</typeparam>
        /// <param name="requestType">The type of request to handle.</param>
        /// <param name="requestHandler">The handler to invoke for matching requests.</param>
        // NOTE(review): the handler's context argument type (RequestContext<TResult>)
        // is reconstructed; confirm against MessageDispatcher.SetRequestHandler.
        public void SetRequestHandler<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            Func<TParams, RequestContext<TResult>, Task> requestHandler)
        {
            this.MessageDispatcher.SetRequestHandler(
                requestType,
                requestHandler);
        }

        /// <summary>
        /// Registers a handler for the given event type with the MessageDispatcher,
        /// passing false for the override-existing flag.
        /// </summary>
        /// <typeparam name="TParams">The event parameter type.</typeparam>
        /// <param name="eventType">The type of event to handle.</param>
        /// <param name="eventHandler">The handler to invoke for matching events.</param>
        // NOTE(review): the handler's context argument type (EventContext) is
        // reconstructed; confirm against MessageDispatcher.SetEventHandler.
        public void SetEventHandler<TParams>(
            EventType<TParams> eventType,
            Func<TParams, EventContext, Task> eventHandler)
        {
            this.MessageDispatcher.SetEventHandler(
                eventType,
                eventHandler,
                false);
        }

        /// <summary>
        /// Registers a handler for the given event type with the MessageDispatcher.
        /// </summary>
        /// <typeparam name="TParams">The event parameter type.</typeparam>
        /// <param name="eventType">The type of event to handle.</param>
        /// <param name="eventHandler">The handler to invoke for matching events.</param>
        /// <param name="overrideExisting">
        /// Forwarded unchanged to MessageDispatcher.SetEventHandler.
        /// </param>
        public void SetEventHandler<TParams>(
            EventType<TParams> eventType,
            Func<TParams, EventContext, Task> eventHandler,
            bool overrideExisting)
        {
            this.MessageDispatcher.SetEventHandler(
                eventType,
                eventHandler,
                overrideExisting);
        }

        /// <summary>
        /// Completes the pending request, if any, whose id matches the response.
        /// Responses with no matching pending request are ignored.
        /// </summary>
        /// <param name="responseMessage">The response received from the channel.</param>
        private void HandleResponse(Message responseMessage)
        {
            TaskCompletionSource<Message> pendingRequestTask = null;
            bool found;

            // Look up and remove under the lock; complete the task outside it so
            // continuations never run while the lock is held.
            lock (this.pendingRequestsLock)
            {
                found = this.pendingRequests.TryGetValue(responseMessage.Id, out pendingRequestTask);
                if (found)
                {
                    this.pendingRequests.Remove(responseMessage.Id);
                }
            }

            if (found)
            {
                pendingRequestTask.SetResult(responseMessage);
            }
        }

        #endregion

        #region Subclass Lifetime Methods

        /// <summary>
        /// Called from <see cref="Start"/> after the channel and dispatcher are set up.
        /// The default implementation does nothing.
        /// </summary>
        protected virtual Task OnStart()
        {
            return Task.FromResult(true);
        }

        /// <summary>
        /// Called after the channel's connection wait finishes and the dispatcher
        /// has been started. The default implementation does nothing.
        /// </summary>
        protected virtual Task OnConnect()
        {
            return Task.FromResult(true);
        }

        /// <summary>
        /// Called from <see cref="Stop"/> before the dispatcher and channel are stopped.
        /// The default implementation does nothing.
        /// </summary>
        protected virtual Task OnStop()
        {
            return Task.FromResult(true);
        }

        #endregion

        #region Event Handlers

        /// <summary>
        /// Surfaces an unhandled dispatcher exception: faults the exit task when a
        /// caller is blocked in <see cref="WaitForExit"/>; otherwise re-throws it on
        /// the synchronization context captured at construction, if there was one.
        /// </summary>
        private void MessageDispatcher_UnhandledException(object sender, Exception e)
        {
            TaskCompletionSource<bool> exitTask = this.endpointExitedTask;

            if (exitTask != null)
            {
                // TrySetException: the task may already be completed (an earlier
                // exception or Stop), in which case SetException would itself throw.
                exitTask.TrySetException(e);
            }
            else if (this.originalSynchronizationContext != null)
            {
                this.originalSynchronizationContext.Post(o => { throw e; }, null);
            }
        }

        #endregion
    }
}