File Browser: Adding async task exception handling (#504)

* Replacing Dictionary with ConcurrentDictionary since values are accessed in async contexts

* Adding new method to allow async tasks to be executed in the exception continuation

* Adding unit tests for the aforementioned

* Adding exception handling to async tasks in file browser service

* Updating query execution async handling to use the async version

* Removing unnecesary send result from continuewithonfaulted
This commit is contained in:
Benjamin Russell
2017-10-19 11:25:29 -07:00
committed by GitHub
parent 4b66203dfc
commit 9600125186
5 changed files with 220 additions and 107 deletions

View File

@@ -4,8 +4,7 @@
// //
using System; using System;
using System.Collections.Generic; using System.Collections.Concurrent;
using System.Data;
using System.Data.Common; using System.Data.Common;
using System.Data.SqlClient; using System.Data.SqlClient;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -14,6 +13,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection;
using Microsoft.SqlTools.ServiceLayer.FileBrowser.Contracts; using Microsoft.SqlTools.ServiceLayer.FileBrowser.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.FileBrowser namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
{ {
@@ -26,9 +26,9 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
public static FileBrowserService Instance => LazyInstance.Value; public static FileBrowserService Instance => LazyInstance.Value;
// Cache file browser operations for expanding node request // Cache file browser operations for expanding node request
private Dictionary<string, FileBrowserOperation> ownerToFileBrowserMap = new Dictionary<string, FileBrowserOperation>(); private readonly ConcurrentDictionary<string, FileBrowserOperation> ownerToFileBrowserMap = new ConcurrentDictionary<string, FileBrowserOperation>();
private Dictionary<string, ValidatePathsCallback> validatePathsCallbackMap = new Dictionary<string, ValidatePathsCallback>(); private readonly ConcurrentDictionary<string, ValidatePathsCallback> validatePathsCallbackMap = new ConcurrentDictionary<string, ValidatePathsCallback>();
private ConnectionService connectionService = null; private ConnectionService connectionService;
/// <summary> /// <summary>
/// Signature for callback method that validates the selected file paths /// Signature for callback method that validates the selected file paths
@@ -52,23 +52,6 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
} }
} }
/// <summary>
/// Service host object for sending/receiving requests/events.
/// Internal for testing purposes.
/// </summary>
internal IProtocolEndpoint ServiceHost
{
get;
set;
}
/// <summary>
/// Constructor
/// </summary>
public FileBrowserService()
{
}
/// <summary> /// <summary>
/// Register validate path callback /// Register validate path callback
/// </summary> /// </summary>
@@ -76,23 +59,15 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
/// <param name="callback"></param> /// <param name="callback"></param>
public void RegisterValidatePathsCallback(string service, ValidatePathsCallback callback) public void RegisterValidatePathsCallback(string service, ValidatePathsCallback callback)
{ {
if (this.validatePathsCallbackMap.ContainsKey(service)) validatePathsCallbackMap.AddOrUpdate(service, callback, (key, oldValue) => callback);
{
this.validatePathsCallbackMap.Remove(service);
}
this.validatePathsCallbackMap.Add(service, callback);
} }
/// <summary> /// <summary>
/// Initializes the service instance /// Initializes the service instance
/// </summary> /// </summary>
/// <param name="serviceHost"></param> /// <param name="serviceHost">Service host to register handlers with</param>
/// <param name="context"></param>
public void InitializeService(ServiceHost serviceHost) public void InitializeService(ServiceHost serviceHost)
{ {
this.ServiceHost = serviceHost;
// Open a file browser // Open a file browser
serviceHost.SetRequestHandler(FileBrowserOpenRequest.Type, HandleFileBrowserOpenRequest); serviceHost.SetRequestHandler(FileBrowserOpenRequest.Type, HandleFileBrowserOpenRequest);
@@ -108,13 +83,12 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
#region request handlers #region request handlers
internal async Task HandleFileBrowserOpenRequest( internal async Task HandleFileBrowserOpenRequest(FileBrowserOpenParams fileBrowserParams, RequestContext<bool> requestContext)
FileBrowserOpenParams fileBrowserParams,
RequestContext<bool> requestContext)
{ {
try try
{ {
var task = Task.Run(() => RunFileBrowserOpenTask(fileBrowserParams)); var task = Task.Run(() => RunFileBrowserOpenTask(fileBrowserParams, requestContext))
.ContinueWithOnFaulted(null);
await requestContext.SendResult(true); await requestContext.SendResult(true);
} }
catch catch
@@ -123,13 +97,12 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
} }
} }
internal async Task HandleFileBrowserExpandRequest( internal async Task HandleFileBrowserExpandRequest(FileBrowserExpandParams fileBrowserParams, RequestContext<bool> requestContext)
FileBrowserExpandParams fileBrowserParams,
RequestContext<bool> requestContext)
{ {
try try
{ {
var task = Task.Run(() => RunFileBrowserExpandTask(fileBrowserParams)); var task = Task.Run(() => RunFileBrowserExpandTask(fileBrowserParams, requestContext))
.ContinueWithOnFaulted(null);
await requestContext.SendResult(true); await requestContext.SendResult(true);
} }
catch catch
@@ -138,13 +111,12 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
} }
} }
internal async Task HandleFileBrowserValidateRequest( internal async Task HandleFileBrowserValidateRequest(FileBrowserValidateParams fileBrowserParams, RequestContext<bool> requestContext)
FileBrowserValidateParams fileBrowserParams,
RequestContext<bool> requestContext)
{ {
try try
{ {
var task = Task.Run(() => RunFileBrowserValidateTask(fileBrowserParams)); var task = Task.Run(() => RunFileBrowserValidateTask(fileBrowserParams, requestContext))
.ContinueWithOnFaulted(null);
await requestContext.SendResult(true); await requestContext.SendResult(true);
} }
catch catch
@@ -158,22 +130,15 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
RequestContext<FileBrowserCloseResponse> requestContext) RequestContext<FileBrowserCloseResponse> requestContext)
{ {
FileBrowserCloseResponse response = new FileBrowserCloseResponse(); FileBrowserCloseResponse response = new FileBrowserCloseResponse();
if (this.ownerToFileBrowserMap.ContainsKey(fileBrowserParams.OwnerUri)) FileBrowserOperation removedOperation;
{ response.Succeeded = ownerToFileBrowserMap.TryRemove(fileBrowserParams.OwnerUri, out removedOperation);
this.ownerToFileBrowserMap.Remove(fileBrowserParams.OwnerUri);
response.Succeeded = true;
}
else
{
response.Succeeded = false;
}
await requestContext.SendResult(response); await requestContext.SendResult(response);
} }
#endregion #endregion
internal async Task RunFileBrowserOpenTask(FileBrowserOpenParams fileBrowserParams) internal async Task RunFileBrowserOpenTask(FileBrowserOpenParams fileBrowserParams, RequestContext<bool> requestContext)
{ {
FileBrowserOpenedParams result = new FileBrowserOpenedParams(); FileBrowserOpenedParams result = new FileBrowserOpenedParams();
@@ -189,7 +154,7 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
connInfo.TryGetConnection(ConnectionType.Default, out dbConn); connInfo.TryGetConnection(ConnectionType.Default, out dbConn);
if (dbConn != null) if (dbConn != null)
{ {
conn = ReliableConnectionHelper.GetAsSqlConnection((IDbConnection)dbConn); conn = ReliableConnectionHelper.GetAsSqlConnection(dbConn);
} }
} }
@@ -198,11 +163,7 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
FileBrowserOperation browser = new FileBrowserOperation(conn, fileBrowserParams.ExpandPath, fileBrowserParams.FileFilters); FileBrowserOperation browser = new FileBrowserOperation(conn, fileBrowserParams.ExpandPath, fileBrowserParams.FileFilters);
browser.PopulateFileTree(); browser.PopulateFileTree();
if (this.ownerToFileBrowserMap.ContainsKey(fileBrowserParams.OwnerUri)) ownerToFileBrowserMap.AddOrUpdate(fileBrowserParams.OwnerUri, browser, (key, value) => browser);
{
this.ownerToFileBrowserMap.Remove(fileBrowserParams.OwnerUri);
}
this.ownerToFileBrowserMap.Add(fileBrowserParams.OwnerUri, browser);
result.OwnerUri = fileBrowserParams.OwnerUri; result.OwnerUri = fileBrowserParams.OwnerUri;
result.FileTree = browser.FileTree; result.FileTree = browser.FileTree;
@@ -219,25 +180,21 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
result.Message = ex.Message; result.Message = ex.Message;
} }
await ServiceHost.SendEvent(FileBrowserOpenedNotification.Type, result); await requestContext.SendEvent(FileBrowserOpenedNotification.Type, result);
} }
internal async Task RunFileBrowserExpandTask(FileBrowserExpandParams fileBrowserParams) internal async Task RunFileBrowserExpandTask(FileBrowserExpandParams fileBrowserParams, RequestContext<bool> requestContext)
{ {
FileBrowserExpandedParams result = new FileBrowserExpandedParams(); FileBrowserExpandedParams result = new FileBrowserExpandedParams();
try try
{ {
if (this.ownerToFileBrowserMap.ContainsKey(fileBrowserParams.OwnerUri)) FileBrowserOperation browser;
result.Succeeded = ownerToFileBrowserMap.TryGetValue(fileBrowserParams.OwnerUri, out browser);
if (result.Succeeded && browser != null)
{ {
FileBrowserOperation browser = this.ownerToFileBrowserMap[fileBrowserParams.OwnerUri];
result.Children = browser.GetChildren(fileBrowserParams.ExpandPath).ToArray(); result.Children = browser.GetChildren(fileBrowserParams.ExpandPath).ToArray();
result.ExpandPath = fileBrowserParams.ExpandPath; result.ExpandPath = fileBrowserParams.ExpandPath;
result.OwnerUri = fileBrowserParams.OwnerUri; result.OwnerUri = fileBrowserParams.OwnerUri;
result.Succeeded = true;
}
else
{
result.Succeeded = false;
} }
} }
catch (Exception ex) catch (Exception ex)
@@ -246,22 +203,23 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
result.Message = ex.Message; result.Message = ex.Message;
} }
await ServiceHost.SendEvent(FileBrowserExpandedNotification.Type, result); await requestContext.SendEvent(FileBrowserExpandedNotification.Type, result);
} }
internal async Task RunFileBrowserValidateTask(FileBrowserValidateParams fileBrowserParams) internal async Task RunFileBrowserValidateTask(FileBrowserValidateParams fileBrowserParams, RequestContext<bool> requestContext)
{ {
FileBrowserValidatedParams result = new FileBrowserValidatedParams(); FileBrowserValidatedParams result = new FileBrowserValidatedParams();
try try
{ {
if (this.validatePathsCallbackMap.ContainsKey(fileBrowserParams.ServiceType) ValidatePathsCallback callback;
&& this.validatePathsCallbackMap[fileBrowserParams.ServiceType] != null if (validatePathsCallbackMap.TryGetValue(fileBrowserParams.ServiceType, out callback)
&& callback != null
&& fileBrowserParams.SelectedFiles != null && fileBrowserParams.SelectedFiles != null
&& fileBrowserParams.SelectedFiles.Length > 0) && fileBrowserParams.SelectedFiles.Length > 0)
{ {
string errorMessage; string errorMessage;
result.Succeeded = this.validatePathsCallbackMap[fileBrowserParams.ServiceType](new FileBrowserValidateEventArgs result.Succeeded = callback(new FileBrowserValidateEventArgs
{ {
ServiceType = fileBrowserParams.ServiceType, ServiceType = fileBrowserParams.ServiceType,
OwnerUri = fileBrowserParams.OwnerUri, OwnerUri = fileBrowserParams.OwnerUri,
@@ -284,7 +242,7 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser
result.Message = ex.Message; result.Message = ex.Message;
} }
await ServiceHost.SendEvent(FileBrowserValidatedNotification.Type, result); await requestContext.SendEvent(FileBrowserValidatedNotification.Type, result);
} }
} }
} }

View File

@@ -277,9 +277,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
public void Execute() public void Execute()
{ {
ExecutionTask = Task.Run(ExecuteInternal) ExecutionTask = Task.Run(ExecuteInternal)
.ContinueWithOnFaulted(t => .ContinueWithOnFaulted(async t =>
{ {
QueryFailed?.Invoke(this, t.Exception).Wait(); if (QueryFailed != null)
{
await QueryFailed(this, t.Exception);
}
}); });
} }

View File

@@ -136,12 +136,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
#region Properties #region Properties
/// <summary>
/// Whether the resultSet is in the process of being disposed
/// </summary>
/// <returns></returns>
internal bool IsBeingDisposed { get; private set; }
/// <summary> /// <summary>
/// The columns for this result set /// The columns for this result set
/// </summary> /// </summary>
@@ -506,9 +500,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}); });
// Add exception handling to the save task // Add exception handling to the save task
Task taskWithHandling = saveAsTask.ContinueWithOnFaulted(t => Task taskWithHandling = saveAsTask.ContinueWithOnFaulted(async t =>
{ {
failureHandler?.Invoke(saveParams, t.Exception.Message).Wait(); if (failureHandler != null)
{
await failureHandler(saveParams, t.Exception.Message);
}
}); });
// If saving the task fails, return a failure // If saving the task fails, return a failure
@@ -538,7 +535,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
return; return;
} }
IsBeingDisposed = true;
// Check if saveTasks are running for this ResultSet // Check if saveTasks are running for this ResultSet
if (!SaveTasks.IsEmpty) if (!SaveTasks.IsEmpty)
{ {
@@ -550,7 +546,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
fileStreamFactory.DisposeFile(outputFileName); fileStreamFactory.DisposeFile(outputFileName);
} }
disposed = true; disposed = true;
IsBeingDisposed = false;
}); });
} }
else else
@@ -561,14 +556,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
fileStreamFactory.DisposeFile(outputFileName); fileStreamFactory.DisposeFile(outputFileName);
} }
disposed = true; disposed = true;
IsBeingDisposed = false;
} }
} }
#endregion #endregion
#region Private Helper Methods #region Private Helper Methods
/// <summary> /// <summary>
/// If the result set represented by this class corresponds to a single XML /// If the result set represented by this class corresponds to a single XML
/// column that contains results of "for xml" query, set isXml = true /// column that contains results of "for xml" query, set isXml = true

View File

@@ -18,33 +18,83 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
/// <remarks> /// <remarks>
/// This will effectively swallow exceptions in the task chain. /// This will effectively swallow exceptions in the task chain.
/// </remarks> /// </remarks>
/// <param name="task">The task to continue</param> /// <param name="antecedent">The task to continue</param>
/// <param name="continuationAction"> /// <param name="continuationAction">
/// An optional operation to perform after exception handling has occurred /// An optional operation to perform after exception handling has occurred
/// </param> /// </param>
/// <returns>Task with exception handling on continuation</returns> /// <returns>Task with exception handling on continuation</returns>
public static Task ContinueWithOnFaulted(this Task task, Action<Task> continuationAction) public static Task ContinueWithOnFaulted(this Task antecedent, Action<Task> continuationAction)
{ {
return task.ContinueWith(t => return antecedent.ContinueWith(task =>
{ {
// If the task hasn't faulted or has an exception, skip processing // If the task hasn't faulted or has an exception, skip processing
if (!t.IsFaulted || t.Exception == null) if (!task.IsFaulted || task.Exception == null)
{ {
return; return;
} }
// Construct an error message for an aggregate exception and log it LogTaskExceptions(task.Exception);
StringBuilder sb = new StringBuilder("Unhandled exception(s) in async task:");
foreach (Exception e in task.Exception.InnerExceptions)
{
sb.AppendLine($"{e.GetType().Name}: {e.Message}");
sb.AppendLine(e.StackTrace);
}
Logger.Write(LogLevel.Error, sb.ToString());
// Run the continuation task that was provided // Run the continuation task that was provided
continuationAction?.Invoke(t); try
{
continuationAction?.Invoke(task);
}
catch (Exception e)
{
Logger.Write(LogLevel.Error, $"Exception in exception handling continuation: {e}");
Logger.Write(LogLevel.Error, e.StackTrace);
}
}); });
} }
/// <summary>
/// Adds handling to check the Exception field of a task and log it if the task faulted.
/// This version allows for async code to be ran in the continuation function.
/// </summary>
/// <remarks>
/// This will effectively swallow exceptions in the task chain.
/// </remarks>
/// <param name="antecedent">The task to continue</param>
/// <param name="continuationFunc">
/// An optional operation to perform after exception handling has occurred
/// </param>
/// <returns>Task with exception handling on continuation</returns>
public static Task ContinueWithOnFaulted(this Task antecedent, Func<Task, Task> continuationFunc)
{
return antecedent.ContinueWith(task =>
{
// If the task hasn't faulted or doesn't have an exception, skip processing
if (!task.IsFaulted || task.Exception == null)
{
return;
}
LogTaskExceptions(task.Exception);
// Run the continuation task that was provided
try
{
continuationFunc?.Invoke(antecedent).Wait();
}
catch (Exception e)
{
Logger.Write(LogLevel.Error, $"Exception in exception handling continuation: {e}");
Logger.Write(LogLevel.Error, e.StackTrace);
}
});
}
private static void LogTaskExceptions(AggregateException exception)
{
// Construct an error message for an aggregate exception and log it
StringBuilder sb = new StringBuilder("Unhandled exception(s) in async task:");
foreach (Exception e in exception.InnerExceptions)
{
sb.AppendLine($"{e.GetType().Name}: {e.Message}");
sb.AppendLine(e.StackTrace);
}
Logger.Write(LogLevel.Error, sb.ToString());
}
} }
} }

View File

@@ -12,14 +12,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
{ {
public class TaskExtensionTests public class TaskExtensionTests
{ {
#region Continue with Action
[Fact] [Fact]
public async Task ContinueWithOnFaultedNullContinuation() public async Task ContinueWithOnFaultedActionNullContinuation()
{ {
// Setup: Create a task that will definitely fault // Setup: Create a task that will definitely fault
Task failureTask = new Task(() => throw new Exception("It fail!")); Task failureTask = new Task(() => { throw new Exception("It fail!"); });
// If: I continue on fault and start the task // If: I continue on fault and start the task
Task continuationTask = failureTask.ContinueWithOnFaulted(null); Task continuationTask = failureTask.ContinueWithOnFaulted((Action<Task>)null);
failureTask.Start(); failureTask.Start();
await continuationTask; await continuationTask;
@@ -28,11 +30,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
} }
[Fact] [Fact]
public async Task ContinueWithOnFaultedContinuatation() public async Task ContinueWithOnFaultedActionContinuatation()
{ {
// Setup: // Setup:
// ... Create a new task that will definitely fault // ... Create a new task that will definitely fault
Task failureTask = new Task(() => throw new Exception("It fail!")); Task failureTask = new Task(() => { throw new Exception("It fail!"); });
// ... Create a quick continuation task that will signify if it's been called // ... Create a quick continuation task that will signify if it's been called
Task providedTask = null; Task providedTask = null;
@@ -49,5 +51,111 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
// ... The continuation action should have been called with the original failure task // ... The continuation action should have been called with the original failure task
Assert.Equal(failureTask, providedTask); Assert.Equal(failureTask, providedTask);
} }
[Fact]
public async Task ContinueWithOnFaultedActionExceptionInContinuation()
{
// Setup:
// ... Create a new task that will definitely fault
Task failureTask = new Task(() => { throw new Exception("It fail!"); });
// ... Create a quick continuation task that will signify if it's been called
Task providedTask = null;
// If: I continue on fault, with a continuation task that will fail
Action<Task> failureContinuation = task =>
{
providedTask = task;
throw new Exception("It fail!");
};
Task continuationTask = failureTask.ContinueWithOnFaulted(failureContinuation);
failureTask.Start();
await continuationTask;
// Then:
// ... The task should have completed without fault
Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status);
// ... The continuation action should have been called with the original failure task
Assert.Equal(failureTask, providedTask);
}
#endregion
#region Continue with Task
[Fact]
public async Task ContinueWithOnFaultedFuncNullContinuation()
{
// Setup: Create a task that will definitely fault
Task failureTask = new Task(() => { throw new Exception("It fail!"); });
// If: I continue on fault and start the task
// ReSharper disable once RedundantCast -- Just to enforce we're running the right overload
Task continuationTask = failureTask.ContinueWithOnFaulted((Func<Task, Task>)null);
failureTask.Start();
await continuationTask;
// Then: The task should have completed without fault
Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status);
}
[Fact]
public async Task ContinueWithOnFaultedFuncContinuatation()
{
// Setup:
// ... Create a new task that will definitely fault
Task failureTask = new Task(() => { throw new Exception("It fail!"); });
// ... Create a quick continuation task that will signify if it's been called
Task providedTask = null;
// If: I continue on fault, with a continuation task
Func<Task, Task> continuationFunc = task =>
{
providedTask = task;
return Task.CompletedTask;
};
Task continuationTask = failureTask.ContinueWithOnFaulted(continuationFunc);
failureTask.Start();
await continuationTask;
// Then:
// ... The task should have completed without fault
Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status);
// ... The continuation action should have been called with the original failure task
Assert.Equal(failureTask, providedTask);
}
[Fact]
public async Task ContinueWithOnFaultedFuncExceptionInContinuation()
{
// Setup:
// ... Create a new task that will definitely fault
Task failureTask = new Task(() => { throw new Exception("It fail!"); });
// ... Create a quick continuation task that will signify if it's been called
Task providedTask = null;
// If: I continue on fault, with a continuation task that will fail
Func<Task, Task> failureContinuation = task =>
{
providedTask = task;
throw new Exception("It fail!");
};
Task continuationTask = failureTask.ContinueWithOnFaulted(failureContinuation);
failureTask.Start();
await continuationTask;
// Then:
// ... The task should have completed without fault
Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status);
// ... The continuation action should have been called with the original failure task
Assert.Equal(failureTask, providedTask);
}
#endregion
} }
} }