Re-enable parallel message processing (#1741)

* add flag to handler

* cleanup

* concurrency control

* add flag for handler setters

* update service flags

* fix event handlers

* more handlers

* make sure behavior is unchanged if flag is off

* cleanup

* add test case for parallel processing

* comments

* stop dispatcher in test

* add log for request lifespan

* cleanup and add comments

* correctly release semaphore

* remove deleted file from merge

* use await for semaphore release

* move handler invocation to await and adjust test

* cleanup exception handling and wrapper

* space

* loose assertion condition to make test stable
This commit is contained in:
Hai Cao
2022-12-11 00:05:33 -08:00
committed by GitHub
parent c304f54ca2
commit f86ebae9b8
37 changed files with 350 additions and 245 deletions

View File

@@ -27,14 +27,22 @@ namespace Microsoft.SqlTools.Hosting.Protocol
internal Dictionary<string, Func<Message, MessageWriter, Task>> requestHandlers =
new Dictionary<string, Func<Message, MessageWriter, Task>>();
internal Dictionary<string, bool> requestHandlerParallelismMap =
new Dictionary<string, bool>();
internal Dictionary<string, Func<Message, MessageWriter, Task>> eventHandlers =
new Dictionary<string, Func<Message, MessageWriter, Task>>();
internal Dictionary<string, bool> eventHandlerParallelismMap =
new Dictionary<string, bool>();
private Action<Message> responseHandler;
private CancellationTokenSource messageLoopCancellationToken =
new CancellationTokenSource();
private SemaphoreSlim semaphore = new SemaphoreSlim(10); // Limit to 10 threads to begin with, ideally there shouldn't be any limitation
#endregion
#region Properties
@@ -112,7 +120,8 @@ namespace Microsoft.SqlTools.Hosting.Protocol
public void SetRequestHandler<TParams, TResult>(
RequestType<TParams, TResult> requestType,
Func<TParams, RequestContext<TResult>, Task> requestHandler,
bool overrideExisting)
bool overrideExisting,
bool isParallelProcessingSupported = false)
{
if (overrideExisting)
{
@@ -120,16 +129,18 @@ namespace Microsoft.SqlTools.Hosting.Protocol
this.requestHandlers.Remove(requestType.MethodName);
}
this.requestHandlerParallelismMap.Add(requestType.MethodName, isParallelProcessingSupported);
this.requestHandlers.Add(
requestType.MethodName,
async (requestMessage, messageWriter) =>
{
Logger.Write(TraceEventType.Verbose, $"Processing message with id[{requestMessage.Id}], of type[{requestMessage.MessageType}] and method[{requestMessage.Method}]");
var requestContext =
new RequestContext<TResult>(
requestMessage,
messageWriter);
try
{
{
TParams typedParams = default(TParams);
if (requestMessage.Contents != null)
{
@@ -144,6 +155,7 @@ namespace Microsoft.SqlTools.Hosting.Protocol
}
await requestHandler(typedParams, requestContext);
Logger.Write(TraceEventType.Verbose, $"Finished processing message with id[{requestMessage.Id}], of type[{requestMessage.MessageType}] and method[{requestMessage.Method}]");
}
catch (Exception ex)
{
@@ -166,7 +178,8 @@ namespace Microsoft.SqlTools.Hosting.Protocol
public void SetEventHandler<TParams>(
EventType<TParams> eventType,
Func<TParams, EventContext, Task> eventHandler,
bool overrideExisting)
bool overrideExisting,
bool isParallelProcessingSupported = false)
{
if (overrideExisting)
{
@@ -174,14 +187,16 @@ namespace Microsoft.SqlTools.Hosting.Protocol
this.eventHandlers.Remove(eventType.MethodName);
}
this.eventHandlerParallelismMap.Add(eventType.MethodName, isParallelProcessingSupported);
this.eventHandlers.Add(
eventType.MethodName,
async (eventMessage, messageWriter) =>
{
Logger.Write(TraceEventType.Verbose, $"Processing message with id[{eventMessage.Id}], of type[{eventMessage.MessageType}] and method[{eventMessage.Method}]");
var eventContext = new EventContext(messageWriter);
TParams typedParams = default(TParams);
try
{
{
if (eventMessage.Contents != null)
{
try
@@ -194,6 +209,7 @@ namespace Microsoft.SqlTools.Hosting.Protocol
}
}
await eventHandler(typedParams, eventContext);
Logger.Write(TraceEventType.Verbose, $"Finished processing message with id[{eventMessage.Id}], of type[{eventMessage.MessageType}] and method[{eventMessage.Method}]");
}
catch (Exception ex)
{
@@ -284,19 +300,13 @@ namespace Microsoft.SqlTools.Hosting.Protocol
Message messageToDispatch,
MessageWriter messageWriter)
{
Task handlerToAwait = null;
Func<Message, MessageWriter, Task> handlerToAwait = null;
bool isParallelProcessingSupported = false;
if (messageToDispatch.MessageType == MessageType.Request)
{
Func<Message, MessageWriter, Task> requestHandler = null;
if (this.requestHandlers.TryGetValue(messageToDispatch.Method, out requestHandler))
{
handlerToAwait = requestHandler(messageToDispatch, messageWriter);
}
// else
// {
// // TODO: Message not supported error
// }
this.requestHandlers.TryGetValue(messageToDispatch.Method, out handlerToAwait);
this.requestHandlerParallelismMap.TryGetValue(messageToDispatch.Method, out isParallelProcessingSupported);
}
else if (messageToDispatch.MessageType == MessageType.Response)
{
@@ -307,15 +317,8 @@ namespace Microsoft.SqlTools.Hosting.Protocol
}
else if (messageToDispatch.MessageType == MessageType.Event)
{
Func<Message, MessageWriter, Task> eventHandler = null;
if (this.eventHandlers.TryGetValue(messageToDispatch.Method, out eventHandler))
{
handlerToAwait = eventHandler(messageToDispatch, messageWriter);
}
else
{
// TODO: Message not supported error
}
this.eventHandlers.TryGetValue(messageToDispatch.Method, out handlerToAwait);
this.eventHandlerParallelismMap.TryGetValue(messageToDispatch.Method, out isParallelProcessingSupported);
}
// else
// {
@@ -324,39 +327,37 @@ namespace Microsoft.SqlTools.Hosting.Protocol
if (handlerToAwait != null)
{
if (this.ParallelMessageProcessing)
try
{
// Run the task in a separate thread so that the main
// thread is not blocked.
_ = Task.Run(() =>
if (this.ParallelMessageProcessing && isParallelProcessingSupported)
{
_ = RunTask(handlerToAwait);
});
// Run the task in a separate thread so that the main
// thread is not blocked. Use semaphore to limit the degree of parallelism.
await semaphore.WaitAsync();
_ = Task.Run(async () =>
{
await handlerToAwait(messageToDispatch, messageWriter);
semaphore.Release();
});
}
else
{
await handlerToAwait(messageToDispatch, messageWriter);
}
}
else
catch (TaskCanceledException e)
{
await RunTask(handlerToAwait);
// Some tasks may be cancelled due to legitimate
// timeouts so don't let those exceptions go higher.
Logger.Write(TraceEventType.Verbose, string.Format("A TaskCanceledException occurred in the request handler: {0}", e.ToString()));
}
}
}
private async Task RunTask(Task task)
{
try
{
await task;
}
catch (TaskCanceledException)
{
// Some tasks may be cancelled due to legitimate
// timeouts so don't let those exceptions go higher.
}
catch (Exception e)
{
if (!(e is AggregateException && ((AggregateException)e).InnerExceptions[0] is TaskCanceledException))
catch (Exception e)
{
// Log the error but don't rethrow it to prevent any errors in the handler from crashing the service
Logger.Write(TraceEventType.Error, string.Format("An unexpected error occured in the request handler: {0}", e.ToString()));
if (!(e is AggregateException && ((AggregateException)e).InnerExceptions[0] is TaskCanceledException))
{
// Log the error but don't rethrow it to prevent any errors in the handler from crashing the service
Logger.Write(TraceEventType.Error, string.Format("An unexpected error occurred in the request handler: {0}", e.ToString()));
}
}
}
}