//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

#nullable disable

using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.Hosting.Protocol.Channel;
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Moq;
using NUnit.Framework;

namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Messaging
{
    /// <summary>
    /// Unit tests for <see cref="MessageDispatcher"/>: handler registration, reporting of a
    /// faulted listen task, and parallel message processing.
    /// </summary>
    public class MessageDispatcherTests
    {
        [Test]
        public void SetRequestHandlerWithOverrideTest()
        {
            RequestType<int, int> requestType = RequestType<int, int>.Create("test/requestType");
            var dispatcher = new MessageDispatcher(new Mock<ChannelBase>().Object);

            dispatcher.SetRequestHandler(
                requestType,
                (i, j) => Task.CompletedTask,
                true);

            Assert.True(dispatcher.requestHandlers.Count > 0);
        }

        [Test]
        public void SetEventHandlerTest()
        {
            EventType<int> eventType = EventType<int>.Create("test/eventType");
            var dispatcher = new MessageDispatcher(new Mock<ChannelBase>().Object);

            dispatcher.SetEventHandler(
                eventType,
                (i, j) => Task.CompletedTask);

            Assert.True(dispatcher.eventHandlers.Count > 0);
        }

        [Test]
        public void SetEventHandlerWithOverrideTest()
        {
            EventType<int> eventType = EventType<int>.Create("test/eventType");
            var dispatcher = new MessageDispatcher(new Mock<ChannelBase>().Object);

            dispatcher.SetEventHandler(
                eventType,
                (i, j) => Task.CompletedTask,
                true);

            Assert.True(dispatcher.eventHandlers.Count > 0);
        }

        [Test]
        public void OnListenTaskCompletedFaultedTaskTest()
        {
            // Build an already-faulted task directly, rather than running a throwing task and
            // swallowing its exception in an empty catch block.
            Task faultedListenTask = Task.FromException(new Exception("Simulated listen task failure"));
            bool handlerCalled = false;
            var dispatcher = new MessageDispatcher(new Mock<ChannelBase>().Object);
            dispatcher.UnhandledException += (s, e) => handlerCalled = true;

            dispatcher.OnListenTaskCompleted(faultedListenTask);

            Assert.True(handlerCalled);
        }

        [Test]
        public async Task ParallelMessageProcessingTest()
        {
            const int numOfRequests = 10;
            const int msForEachRequest = 1000;

            // Without parallel processing, this should take around numOfRequests * msForEachRequest ms to finish.
            // With parallel processing, it should in theory take around 1 * msForEachRequest ms (perfect parallelism),
            // so the difference should be around (numOfRequests - 1) * msForEachRequest ms.
            // To keep this test stable on machines with weak hardware or few logical cores, the assertion is
            // loosened to only check that parallel processing is faster than sequential processing.
            long sequentialMs = await GetTimeToHandleRequestsAsync(false, numOfRequests, msForEachRequest);
            long parallelMs = await GetTimeToHandleRequestsAsync(true, numOfRequests, msForEachRequest);

            Assert.IsTrue(
                sequentialMs > parallelMs,
                $"Expected parallel processing ({parallelMs} ms) to be faster than sequential processing ({sequentialMs} ms).");
        }

        /// <summary>
        /// Measures how long the dispatcher takes to finish handling a given number of requests.
        /// </summary>
        /// <param name="parallelMessageProcessing">Whether to enable parallel processing on the dispatcher.</param>
        /// <param name="numOfRequests">Number of requests that must complete before the clock stops.</param>
        /// <param name="msForEachRequest">Rough time, in ms, that each simulated request takes.</param>
        /// <returns>Elapsed milliseconds until <paramref name="numOfRequests"/> requests had completed.</returns>
        private static async Task<long> GetTimeToHandleRequestsAsync(bool parallelMessageProcessing, int numOfRequests, int msForEachRequest)
        {
            RequestType<int, int> requestType = RequestType<int, int>.Create("test/requestType");
            var mockChannel = new Mock<ChannelBase>();

            // The mocked reader hands back a completed request on every call, so the dispatcher
            // keeps receiving requests for as long as it is running.
            mockChannel.Setup(c => c.MessageReader.ReadMessage())
                .Returns(Task.FromResult(Message.Request("1", "test/requestType", null)));

            var dispatcher = new MessageDispatcher(mockChannel.Object);
            dispatcher.ParallelMessageProcessing = parallelMessageProcessing;

            int handledCount = 0;
            // Completed exactly once, by the handler that finishes the numOfRequests-th request.
            // This provides a properly synchronized hand-off back to the test.
            var allHandled = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
            Stopwatch stopwatch = Stopwatch.StartNew();

            var handler = async (int _, RequestContext<int> _) =>
            {
                // simulate a slow synchronous call
                Thread.Sleep(msForEachRequest / 2);
                // simulate a slow asynchronous call
                await Task.Delay(msForEachRequest / 2);

                // Interlocked keeps the count race-free when handlers run concurrently. Only the
                // handler reaching exactly numOfRequests stops the clock. Later requests (the mock
                // never stops producing them) complete normally instead of blocking forever.
                if (Interlocked.Increment(ref handledCount) == numOfRequests)
                {
                    stopwatch.Stop();
                    allHandled.TrySetResult(true);
                }
            };
            dispatcher.SetRequestHandler(requestType, handler, false, true);

            // Generous upper bound so that a regression fails the test instead of hanging it.
            TimeSpan timeout = TimeSpan.FromMilliseconds(10.0 * numOfRequests * msForEachRequest);
            using var timeoutCts = new CancellationTokenSource();

            dispatcher.Start();
            try
            {
                Task finished = await Task.WhenAny(allHandled.Task, Task.Delay(timeout, timeoutCts.Token));
                if (finished != allHandled.Task)
                {
                    Assert.Fail($"Timed out after {timeout.TotalMilliseconds} ms waiting for {numOfRequests} requests to be handled.");
                }
            }
            finally
            {
                // Release the pending delay timer and always stop the dispatcher, even on failure.
                timeoutCts.Cancel();
                dispatcher.Stop();
            }

            return stopwatch.ElapsedMilliseconds;
        }
    }
}