diff --git a/src/Microsoft.SqlTools.ServiceLayer/FileBrowser/FileBrowserService.cs b/src/Microsoft.SqlTools.ServiceLayer/FileBrowser/FileBrowserService.cs index ea83ea7e..de555582 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/FileBrowser/FileBrowserService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/FileBrowser/FileBrowserService.cs @@ -4,8 +4,7 @@ // using System; -using System.Collections.Generic; -using System.Data; +using System.Collections.Concurrent; using System.Data.Common; using System.Data.SqlClient; using System.Threading.Tasks; @@ -14,6 +13,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.FileBrowser.Contracts; using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.FileBrowser { @@ -26,9 +26,9 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser public static FileBrowserService Instance => LazyInstance.Value; // Cache file browser operations for expanding node request - private Dictionary ownerToFileBrowserMap = new Dictionary(); - private Dictionary validatePathsCallbackMap = new Dictionary(); - private ConnectionService connectionService = null; + private readonly ConcurrentDictionary ownerToFileBrowserMap = new ConcurrentDictionary(); + private readonly ConcurrentDictionary validatePathsCallbackMap = new ConcurrentDictionary(); + private ConnectionService connectionService; /// /// Signature for callback method that validates the selected file paths @@ -52,23 +52,6 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser } } - /// - /// Service host object for sending/receiving requests/events. - /// Internal for testing purposes. - /// - internal IProtocolEndpoint ServiceHost - { - get; - set; - } - - /// - /// Constructor - /// - public FileBrowserService() - { - } - /// /// Register validate path callback /// @@ -76,23 +59,15 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser /// public void RegisterValidatePathsCallback(string service, ValidatePathsCallback callback) { - if (this.validatePathsCallbackMap.ContainsKey(service)) - { - this.validatePathsCallbackMap.Remove(service); - } - - this.validatePathsCallbackMap.Add(service, callback); + validatePathsCallbackMap.AddOrUpdate(service, callback, (key, oldValue) => callback); } /// /// Initializes the service instance /// - /// - /// + /// Service host to register handlers with public void InitializeService(ServiceHost serviceHost) { - this.ServiceHost = serviceHost; - // Open a file browser serviceHost.SetRequestHandler(FileBrowserOpenRequest.Type, HandleFileBrowserOpenRequest); @@ -108,13 +83,12 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser #region request handlers - internal async Task HandleFileBrowserOpenRequest( - FileBrowserOpenParams fileBrowserParams, - RequestContext requestContext) + internal async Task HandleFileBrowserOpenRequest(FileBrowserOpenParams fileBrowserParams, RequestContext requestContext) { try { - var task = Task.Run(() => RunFileBrowserOpenTask(fileBrowserParams)); + var task = Task.Run(() => RunFileBrowserOpenTask(fileBrowserParams, requestContext)) + .ContinueWithOnFaulted(null); await requestContext.SendResult(true); } catch @@ -123,13 +97,12 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser } } - internal async Task HandleFileBrowserExpandRequest( - FileBrowserExpandParams fileBrowserParams, - RequestContext requestContext) + internal async Task HandleFileBrowserExpandRequest(FileBrowserExpandParams fileBrowserParams, RequestContext requestContext) { try { - var task = Task.Run(() => RunFileBrowserExpandTask(fileBrowserParams)); + var task = Task.Run(() => RunFileBrowserExpandTask(fileBrowserParams, requestContext)) + .ContinueWithOnFaulted(null); await requestContext.SendResult(true); } catch @@ -138,13 +111,12 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser } } - internal async Task HandleFileBrowserValidateRequest( - FileBrowserValidateParams fileBrowserParams, - RequestContext requestContext) + internal async Task HandleFileBrowserValidateRequest(FileBrowserValidateParams fileBrowserParams, RequestContext requestContext) { try { - var task = Task.Run(() => RunFileBrowserValidateTask(fileBrowserParams)); + var task = Task.Run(() => RunFileBrowserValidateTask(fileBrowserParams, requestContext)) + .ContinueWithOnFaulted(null); await requestContext.SendResult(true); } catch @@ -158,22 +130,15 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser RequestContext requestContext) { FileBrowserCloseResponse response = new FileBrowserCloseResponse(); - if (this.ownerToFileBrowserMap.ContainsKey(fileBrowserParams.OwnerUri)) - { - this.ownerToFileBrowserMap.Remove(fileBrowserParams.OwnerUri); - response.Succeeded = true; - } - else - { - response.Succeeded = false; - } + FileBrowserOperation removedOperation; + response.Succeeded = ownerToFileBrowserMap.TryRemove(fileBrowserParams.OwnerUri, out removedOperation); await requestContext.SendResult(response); } #endregion - internal async Task RunFileBrowserOpenTask(FileBrowserOpenParams fileBrowserParams) + internal async Task RunFileBrowserOpenTask(FileBrowserOpenParams fileBrowserParams, RequestContext requestContext) { FileBrowserOpenedParams result = new FileBrowserOpenedParams(); @@ -189,7 +154,7 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser connInfo.TryGetConnection(ConnectionType.Default, out dbConn); if (dbConn != null) { - conn = ReliableConnectionHelper.GetAsSqlConnection((IDbConnection)dbConn); + conn = ReliableConnectionHelper.GetAsSqlConnection(dbConn); } } @@ -198,11 +163,7 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser FileBrowserOperation browser = new FileBrowserOperation(conn, fileBrowserParams.ExpandPath, fileBrowserParams.FileFilters); browser.PopulateFileTree(); - if (this.ownerToFileBrowserMap.ContainsKey(fileBrowserParams.OwnerUri)) - { - this.ownerToFileBrowserMap.Remove(fileBrowserParams.OwnerUri); - } - this.ownerToFileBrowserMap.Add(fileBrowserParams.OwnerUri, browser); + ownerToFileBrowserMap.AddOrUpdate(fileBrowserParams.OwnerUri, browser, (key, value) => browser); result.OwnerUri = fileBrowserParams.OwnerUri; result.FileTree = browser.FileTree; @@ -219,25 +180,21 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser result.Message = ex.Message; } - await ServiceHost.SendEvent(FileBrowserOpenedNotification.Type, result); + await requestContext.SendEvent(FileBrowserOpenedNotification.Type, result); } - internal async Task RunFileBrowserExpandTask(FileBrowserExpandParams fileBrowserParams) + internal async Task RunFileBrowserExpandTask(FileBrowserExpandParams fileBrowserParams, RequestContext requestContext) { FileBrowserExpandedParams result = new FileBrowserExpandedParams(); try { - if (this.ownerToFileBrowserMap.ContainsKey(fileBrowserParams.OwnerUri)) + FileBrowserOperation browser; + result.Succeeded = ownerToFileBrowserMap.TryGetValue(fileBrowserParams.OwnerUri, out browser); + if (result.Succeeded && browser != null) { - FileBrowserOperation browser = this.ownerToFileBrowserMap[fileBrowserParams.OwnerUri]; result.Children = browser.GetChildren(fileBrowserParams.ExpandPath).ToArray(); result.ExpandPath = fileBrowserParams.ExpandPath; result.OwnerUri = fileBrowserParams.OwnerUri; - result.Succeeded = true; - } - else - { - result.Succeeded = false; } } catch (Exception ex) @@ -246,22 +203,23 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser result.Message = ex.Message; } - await ServiceHost.SendEvent(FileBrowserExpandedNotification.Type, result); + await requestContext.SendEvent(FileBrowserExpandedNotification.Type, result); } - internal async Task RunFileBrowserValidateTask(FileBrowserValidateParams fileBrowserParams) + internal async Task RunFileBrowserValidateTask(FileBrowserValidateParams fileBrowserParams, RequestContext requestContext) { FileBrowserValidatedParams result = new FileBrowserValidatedParams(); try { - if (this.validatePathsCallbackMap.ContainsKey(fileBrowserParams.ServiceType) - && this.validatePathsCallbackMap[fileBrowserParams.ServiceType] != null + ValidatePathsCallback callback; + if (validatePathsCallbackMap.TryGetValue(fileBrowserParams.ServiceType, out callback) + && callback != null && fileBrowserParams.SelectedFiles != null && fileBrowserParams.SelectedFiles.Length > 0) { string errorMessage; - result.Succeeded = this.validatePathsCallbackMap[fileBrowserParams.ServiceType](new FileBrowserValidateEventArgs + result.Succeeded = callback(new FileBrowserValidateEventArgs { ServiceType = fileBrowserParams.ServiceType, OwnerUri = fileBrowserParams.OwnerUri, @@ -284,7 +242,7 @@ namespace Microsoft.SqlTools.ServiceLayer.FileBrowser result.Message = ex.Message; } - await ServiceHost.SendEvent(FileBrowserValidatedNotification.Type, result); + await requestContext.SendEvent(FileBrowserValidatedNotification.Type, result); } } } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index bfddaddc..46ab9d41 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -277,9 +277,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public void Execute() { ExecutionTask = Task.Run(ExecuteInternal) - .ContinueWithOnFaulted(t => + .ContinueWithOnFaulted(async t => { - QueryFailed?.Invoke(this, t.Exception).Wait(); + if (QueryFailed != null) + { + await QueryFailed(this, t.Exception); + } }); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index df093d33..390a774b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -136,12 +136,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Properties - /// - /// Whether the resultSet is in the process of being disposed - /// - /// - internal bool IsBeingDisposed { get; private set; } - /// /// The columns for this result set /// @@ -506,9 +500,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }); // Add exception handling to the save task - Task taskWithHandling = saveAsTask.ContinueWithOnFaulted(t => + Task taskWithHandling = saveAsTask.ContinueWithOnFaulted(async t => { - failureHandler?.Invoke(saveParams, t.Exception.Message).Wait(); + if (failureHandler != null) + { + await failureHandler(saveParams, t.Exception.Message); + } }); // If saving the task fails, return a failure @@ -538,7 +535,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - IsBeingDisposed = true; // Check if saveTasks are running for this ResultSet if (!SaveTasks.IsEmpty) { @@ -550,7 +546,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution fileStreamFactory.DisposeFile(outputFileName); } disposed = true; - IsBeingDisposed = false; }); } else @@ -561,14 +556,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution fileStreamFactory.DisposeFile(outputFileName); } disposed = true; - IsBeingDisposed = false; } } #endregion #region Private Helper Methods - + /// /// If the result set represented by this class corresponds to a single XML /// column that contains results of "for xml" query, set isXml = true diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/TaskExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/TaskExtensions.cs index 818e6d5c..a60b438b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/TaskExtensions.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/TaskExtensions.cs @@ -18,33 +18,83 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility /// /// This will effectively swallow exceptions in the task chain. /// - /// The task to continue + /// The task to continue /// /// An optional operation to perform after exception handling has occurred /// /// Task with exception handling on continuation - public static Task ContinueWithOnFaulted(this Task task, Action continuationAction) + public static Task ContinueWithOnFaulted(this Task antecedent, Action continuationAction) { - return task.ContinueWith(t => + return antecedent.ContinueWith(task => { // If the task hasn't faulted or has an exception, skip processing - if (!t.IsFaulted || t.Exception == null) + if (!task.IsFaulted || task.Exception == null) { return; } - // Construct an error message for an aggregate exception and log it - StringBuilder sb = new StringBuilder("Unhandled exception(s) in async task:"); - foreach (Exception e in task.Exception.InnerExceptions) - { - sb.AppendLine($"{e.GetType().Name}: {e.Message}"); - sb.AppendLine(e.StackTrace); - } - Logger.Write(LogLevel.Error, sb.ToString()); + LogTaskExceptions(task.Exception); // Run the continuation task that was provided - continuationAction?.Invoke(t); + try + { + continuationAction?.Invoke(task); + } + catch (Exception e) + { + Logger.Write(LogLevel.Error, $"Exception in exception handling continuation: {e}"); + Logger.Write(LogLevel.Error, e.StackTrace); + } }); } + + /// + /// Adds handling to check the Exception field of a task and log it if the task faulted. + /// This version allows for async code to be ran in the continuation function. + /// + /// + /// This will effectively swallow exceptions in the task chain. + /// + /// The task to continue + /// + /// An optional operation to perform after exception handling has occurred + /// + /// Task with exception handling on continuation + public static Task ContinueWithOnFaulted(this Task antecedent, Func continuationFunc) + { + return antecedent.ContinueWith(task => + { + // If the task hasn't faulted or doesn't have an exception, skip processing + if (!task.IsFaulted || task.Exception == null) + { + return; + } + + LogTaskExceptions(task.Exception); + + // Run the continuation task that was provided + try + { + continuationFunc?.Invoke(antecedent).Wait(); + } + catch (Exception e) + { + Logger.Write(LogLevel.Error, $"Exception in exception handling continuation: {e}"); + Logger.Write(LogLevel.Error, e.StackTrace); + } + }); + } + + private static void LogTaskExceptions(AggregateException exception) + { + // Construct an error message for an aggregate exception and log it + StringBuilder sb = new StringBuilder("Unhandled exception(s) in async task:"); + foreach (Exception e in exception.InnerExceptions) + { + sb.AppendLine($"{e.GetType().Name}: {e.Message}"); + sb.AppendLine(e.StackTrace); + } + Logger.Write(LogLevel.Error, sb.ToString()); + } } } \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TaskExtensionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TaskExtensionTests.cs index b287f197..fc95253d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TaskExtensionTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TaskExtensionTests.cs @@ -12,14 +12,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility { public class TaskExtensionTests { + #region Continue with Action + [Fact] - public async Task ContinueWithOnFaultedNullContinuation() + public async Task ContinueWithOnFaultedActionNullContinuation() { // Setup: Create a task that will definitely fault - Task failureTask = new Task(() => throw new Exception("It fail!")); + Task failureTask = new Task(() => { throw new Exception("It fail!"); }); // If: I continue on fault and start the task - Task continuationTask = failureTask.ContinueWithOnFaulted(null); + Task continuationTask = failureTask.ContinueWithOnFaulted((Action)null); failureTask.Start(); await continuationTask; @@ -28,11 +30,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility } [Fact] - public async Task ContinueWithOnFaultedContinuatation() + public async Task ContinueWithOnFaultedActionContinuatation() { // Setup: // ... Create a new task that will definitely fault - Task failureTask = new Task(() => throw new Exception("It fail!")); + Task failureTask = new Task(() => { throw new Exception("It fail!"); }); // ... Create a quick continuation task that will signify if it's been called Task providedTask = null; @@ -49,5 +51,111 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility // ... The continuation action should have been called with the original failure task Assert.Equal(failureTask, providedTask); } + + [Fact] + public async Task ContinueWithOnFaultedActionExceptionInContinuation() + { + // Setup: + // ... Create a new task that will definitely fault + Task failureTask = new Task(() => { throw new Exception("It fail!"); }); + + // ... Create a quick continuation task that will signify if it's been called + Task providedTask = null; + + // If: I continue on fault, with a continuation task that will fail + Action failureContinuation = task => + { + providedTask = task; + throw new Exception("It fail!"); + }; + Task continuationTask = failureTask.ContinueWithOnFaulted(failureContinuation); + failureTask.Start(); + await continuationTask; + + // Then: + // ... The task should have completed without fault + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + + // ... The continuation action should have been called with the original failure task + Assert.Equal(failureTask, providedTask); + } + + #endregion + + #region Continue with Task + + [Fact] + public async Task ContinueWithOnFaultedFuncNullContinuation() + { + // Setup: Create a task that will definitely fault + Task failureTask = new Task(() => { throw new Exception("It fail!"); }); + + // If: I continue on fault and start the task + // ReSharper disable once RedundantCast -- Just to enforce we're running the right overload + Task continuationTask = failureTask.ContinueWithOnFaulted((Func)null); + failureTask.Start(); + await continuationTask; + + // Then: The task should have completed without fault + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + } + + [Fact] + public async Task ContinueWithOnFaultedFuncContinuatation() + { + // Setup: + // ... Create a new task that will definitely fault + Task failureTask = new Task(() => { throw new Exception("It fail!"); }); + + // ... Create a quick continuation task that will signify if it's been called + Task providedTask = null; + + // If: I continue on fault, with a continuation task + Func continuationFunc = task => + { + providedTask = task; + return Task.CompletedTask; + }; + Task continuationTask = failureTask.ContinueWithOnFaulted(continuationFunc); + failureTask.Start(); + await continuationTask; + + // Then: + // ... The task should have completed without fault + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + + // ... The continuation action should have been called with the original failure task + Assert.Equal(failureTask, providedTask); + } + + [Fact] + public async Task ContinueWithOnFaultedFuncExceptionInContinuation() + { + // Setup: + // ... Create a new task that will definitely fault + Task failureTask = new Task(() => { throw new Exception("It fail!"); }); + + // ... Create a quick continuation task that will signify if it's been called + Task providedTask = null; + + // If: I continue on fault, with a continuation task that will fail + Func failureContinuation = task => + { + providedTask = task; + throw new Exception("It fail!"); + }; + Task continuationTask = failureTask.ContinueWithOnFaulted(failureContinuation); + failureTask.Start(); + await continuationTask; + + // Then: + // ... The task should have completed without fault + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + + // ... The continuation action should have been called with the original failure task + Assert.Equal(failureTask, providedTask); + } + + #endregion } } \ No newline at end of file