//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

//
// The following is based upon code from PowerShell Editor Services
// License: https://github.com/PowerShell/PowerShellEditorServices/blob/develop/LICENSE
//

using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.Hosting.Protocol.Channel;
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.Utility;

namespace Microsoft.SqlTools.ServiceLayer.TestDriver.Driver
{
    /// <summary>
    /// Wraps the ProtocolEndpoint class with queues to handle events/requests
    /// </summary>
    public class TestDriverBase
    {
        protected ProtocolEndpoint protocolClient;

        protected StdioClientChannel clientChannel;

        // Received event params, keyed by event method name. Populated by QueueEventsForType.
        private ConcurrentDictionary<string, AsyncQueue<object>> eventQueuePerType =
            new ConcurrentDictionary<string, AsyncQueue<object>>();

        // Received requests, keyed by request method name.
        // NOTE(review): nothing in this class ever adds to this dictionary, so the queue branch
        // of WaitForRequest is currently unreachable; kept for parity with the event path.
        private ConcurrentDictionary<string, AsyncQueue<object>> requestQueuePerType =
            new ConcurrentDictionary<string, AsyncQueue<object>>();

        /// <summary>
        /// Gets the process whose id is reported by <see cref="clientChannel"/>, or null when
        /// the lookup fails for any reason (e.g. the channel is unset or the process has exited).
        /// </summary>
        public Process ServiceProcess
        {
            get
            {
                try
                {
                    return Process.GetProcessById(clientChannel.ProcessId);
                }
                catch
                {
                    // Deliberately best-effort: any failure to resolve the process yields null.
                    return null;
                }
            }
        }

        /// <summary>
        /// Forwards a request to the underlying protocol endpoint.
        /// </summary>
        /// <param name="requestType">The request type to send.</param>
        /// <param name="requestParams">The parameters to send with the request.</param>
        /// <returns>The task returned by <see cref="ProtocolEndpoint"/>'s SendRequest.</returns>
        public Task<TResult> SendRequest<TParams, TResult>(
            RequestType<TParams, TResult> requestType,
            TParams requestParams)
        {
            return this.protocolClient.SendRequest(requestType, requestParams);
        }

        /// <summary>
        /// Forwards an event to the underlying protocol endpoint.
        /// </summary>
        /// <param name="eventType">The event type to send.</param>
        /// <param name="eventParams">The parameters to send with the event.</param>
        /// <returns>The task returned by <see cref="ProtocolEndpoint"/>'s SendEvent.</returns>
        public Task SendEvent<TParams>(EventType<TParams> eventType, TParams eventParams)
        {
            return this.protocolClient.SendEvent(eventType, eventParams);
        }

        /// <summary>
        /// Starts buffering every received event of the given type into a per-type queue,
        /// so that later calls to WaitForEvent dequeue from it instead of installing a
        /// one-shot handler.
        /// </summary>
        /// <param name="eventType">The event type to buffer.</param>
        public void QueueEventsForType<TParams>(EventType<TParams> eventType)
        {
            // Reuse the existing queue for this method name if one was already registered.
            AsyncQueue<object> eventQueue = this.eventQueuePerType.GetOrAdd(
                eventType.MethodName,
                new AsyncQueue<object>());

            this.protocolClient.SetEventHandler(
                eventType,
                (p, ctx) => eventQueue.EnqueueAsync(p));
        }

        /// <summary>
        /// Waits for the next event of the given type.
        /// </summary>
        /// <param name="eventType">The event type to wait for.</param>
        /// <param name="timeout">How long to wait; converted to whole milliseconds.</param>
        /// <returns>The received event parameters.</returns>
        /// <exception cref="TimeoutException">No event arrived within <paramref name="timeout"/>.</exception>
        public async Task<TParams> WaitForEvent<TParams>(
            EventType<TParams> eventType,
            TimeSpan timeout)
        {
            return await WaitForEvent(eventType, (int)timeout.TotalMilliseconds);
        }

        /// <summary>
        /// Waits for the next event of the given type. Uses the queue registered by
        /// QueueEventsForType when there is one; otherwise installs a one-shot handler that
        /// overrides any existing handler for this event type.
        /// </summary>
        /// <param name="eventType">The event type to wait for.</param>
        /// <param name="timeoutMilliseconds">How long to wait, in milliseconds.</param>
        /// <returns>The received event parameters.</returns>
        /// <exception cref="TimeoutException">No event arrived within the timeout.</exception>
        public async Task<TParams> WaitForEvent<TParams>(
            EventType<TParams> eventType,
            int timeoutMilliseconds = 5000)
        {
            Task<TParams> eventTask;

            // Use the event queue if one has been registered
            AsyncQueue<object> eventQueue;
            if (this.eventQueuePerType.TryGetValue(eventType.MethodName, out eventQueue))
            {
                // NOTE(review): on timeout the pending DequeueAsync is abandoned; depending on
                // AsyncQueue's semantics it may still consume (and drop) the next queued event.
                eventTask = CastResultAsync<TParams>(eventQueue.DequeueAsync());
            }
            else
            {
                // Run awaiting continuations off the thread that delivers the event.
                var eventTaskSource = new TaskCompletionSource<TParams>(
                    TaskCreationOptions.RunContinuationsAsynchronously);

                this.protocolClient.SetEventHandler(
                    eventType,
                    (p, ctx) =>
                    {
                        // TrySetResult is atomic: the first event wins and later (or
                        // concurrent) deliveries are ignored instead of throwing.
                        eventTaskSource.TrySetResult(p);
                        return Task.FromResult(true);
                    },
                    true); // Override any existing handler

                eventTask = eventTaskSource.Task;
            }

            return await WaitWithTimeoutAsync(
                eventTask,
                timeoutMilliseconds,
                "Timed out waiting for '{0}' event!",
                eventType.MethodName);
        }

        /// <summary>
        /// Waits for the next incoming request of the given type and returns its parameters
        /// together with the context used to respond to it.
        /// </summary>
        /// <param name="requestType">The request type to wait for.</param>
        /// <param name="timeoutMilliseconds">How long to wait, in milliseconds.</param>
        /// <returns>The request parameters and its request context.</returns>
        /// <exception cref="TimeoutException">No request arrived within the timeout.</exception>
        public async Task<Tuple<TParams, RequestContext<TResponse>>> WaitForRequest<TParams, TResponse>(
            RequestType<TParams, TResponse> requestType,
            int timeoutMilliseconds = 5000)
        {
            Task<Tuple<TParams, RequestContext<TResponse>>> requestTask;

            // Use the request queue if one has been registered
            AsyncQueue<object> requestQueue;
            if (this.requestQueuePerType.TryGetValue(requestType.MethodName, out requestQueue))
            {
                requestTask = CastResultAsync<Tuple<TParams, RequestContext<TResponse>>>(
                    requestQueue.DequeueAsync());
            }
            else
            {
                // Run awaiting continuations off the thread that delivers the request.
                var requestTaskSource = new TaskCompletionSource<Tuple<TParams, RequestContext<TResponse>>>(
                    TaskCreationOptions.RunContinuationsAsynchronously);

                // NOTE(review): unlike WaitForEvent, this does not ask to override an existing
                // handler; if the endpoint rejects duplicate registrations, a second wait for the
                // same request type will fail here — confirm against ProtocolEndpoint.
                this.protocolClient.SetRequestHandler(
                    requestType,
                    (p, ctx) =>
                    {
                        // First request wins; later deliveries are ignored.
                        requestTaskSource.TrySetResult(Tuple.Create(p, ctx));
                        return Task.FromResult(true);
                    });

                requestTask = requestTaskSource.Task;
            }

            return await WaitWithTimeoutAsync(
                requestTask,
                timeoutMilliseconds,
                "Timed out waiting for '{0}' request!",
                requestType.MethodName);
        }

        /// <summary>
        /// Awaits an untyped queue item and casts it to <typeparamref name="T"/>.
        /// Unlike ContinueWith + Result, failures surface unwrapped rather than as AggregateException.
        /// </summary>
        private static async Task<T> CastResultAsync<T>(Task<object> source)
        {
            return (T)await source;
        }

        /// <summary>
        /// Awaits <paramref name="task"/> for at most <paramref name="timeoutMilliseconds"/>.
        /// </summary>
        /// <param name="task">The task to wait on.</param>
        /// <param name="timeoutMilliseconds">Maximum wait in milliseconds (-1 waits indefinitely).</param>
        /// <param name="timeoutMessageFormat">Format string for the timeout message; {0} is the method name.</param>
        /// <param name="methodName">Protocol method name inserted into the timeout message.</param>
        /// <returns>The task's result.</returns>
        /// <exception cref="TimeoutException">The task did not complete within the timeout.</exception>
        private static async Task<T> WaitWithTimeoutAsync<T>(
            Task<T> task,
            int timeoutMilliseconds,
            string timeoutMessageFormat,
            string methodName)
        {
            using (var delayCancellation = new CancellationTokenSource())
            {
                await Task.WhenAny(
                    task,
                    Task.Delay(timeoutMilliseconds, delayCancellation.Token));

                // Release the delay's timer as soon as the wait is over rather than
                // leaving it to run until expiry.
                delayCancellation.Cancel();
            }

            if (!task.IsCompleted)
            {
                throw new TimeoutException(
                    string.Format(timeoutMessageFormat, methodName));
            }

            return await task;
        }
    }
}