diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
index 116e9111..bfddaddc 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs
@@ -1,6 +1,7 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
using System;
using System.Data.Common;
@@ -17,6 +18,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.Utility;
using Microsoft.SqlTools.ServiceLayer.BatchParser.ExecutionEngineCode;
using System.Collections.Generic;
+using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
@@ -25,10 +27,34 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
public class Query : IDisposable
{
+ #region Constants
+
///
/// "Error" code produced by SQL Server when the database context (name) for a connection changes.
///
private const int DatabaseContextChangeErrorNumber = 5701;
+
+ ///
+ /// ON keyword
+ ///
+ private const string On = "ON";
+
+ ///
+ /// OFF keyword
+ ///
+ private const string Off = "OFF";
+
+ ///
+ /// showplan_xml statement
+ ///
+ private const string SetShowPlanXml = "SET SHOWPLAN_XML {0}";
+
+ ///
+ /// statistics xml statement
+ ///
+ private const string SetStatisticsXml = "SET STATISTICS XML {0}";
+
+ #endregion
#region Member Variables
@@ -56,27 +82,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
/// Name of the new database if the database name was changed in the query
///
- private string newDatabaseName;
-
- ///
- /// ON keyword
- ///
- private const string On = "ON";
-
- ///
- /// OFF keyword
- ///
- private const string Off = "OFF";
-
- ///
- /// showplan_xml statement
- ///
- private const string SetShowPlanXml = "SET SHOWPLAN_XML {0}";
-
- ///
- /// statistics xml statement
- ///
- private const string SetStatisticsXml = "SET STATISTICS XML {0}";
+ private string newDatabaseName;
#endregion
@@ -139,6 +145,19 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
#region Events
+ ///
+ /// Delegate type for callback when a query completes or fails
+ ///
+ /// The query that completed
+ public delegate Task QueryAsyncEventHandler(Query query);
+
+ ///
+ /// Delegate type for callback when a query fails
+ ///
+ /// Query that raised the event
+ /// Exception that caused the query to fail
+ public delegate Task QueryAsyncErrorEventHandler(Query query, Exception exception);
+
///
/// Event to be called when a batch is completed.
///
@@ -154,12 +173,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
public event Batch.BatchAsyncEventHandler BatchStarted;
- ///
- /// Delegate type for callback when a query connection fails
- ///
- /// Error message for the failing query
- public delegate Task QueryAsyncErrorEventHandler(Query q, Exception e);
-
///
/// Callback for when the query has completed successfully
///
@@ -179,26 +192,20 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
#region Properties
- ///
- /// Delegate type for callback when a query completes or fails
- ///
- /// The query that completed
- public delegate Task QueryAsyncEventHandler(Query q);
-
///
/// The batches which should run before the user batches
///
- internal List BeforeBatches { get; set; }
+ private List BeforeBatches { get; }
///
/// The batches underneath this query
///
- internal Batch[] Batches { get; set; }
+ internal Batch[] Batches { get; }
///
/// The batches which should run after the user batches
///
- internal List AfterBatches { get; set; }
+ internal List AfterBatches { get; }
///
/// The summaries of the batches underneath this query
@@ -243,7 +250,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
/// The text of the query to execute
///
- public string QueryText { get; set; }
+ public string QueryText { get; }
#endregion
@@ -269,7 +276,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
public void Execute()
{
- ExecutionTask = Task.Run(ExecuteInternal);
+ ExecutionTask = Task.Run(ExecuteInternal)
+ .ContinueWithOnFaulted(t =>
+ {
+ QueryFailed?.Invoke(this, t.Exception).Wait();
+ });
}
///
@@ -338,34 +349,36 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
private async Task ExecuteInternal()
{
- // Mark that we've internally executed
- hasExecuteBeenCalled = true;
-
- // Don't actually execute if there aren't any batches to execute
- if (Batches.Length == 0)
- {
- if (BatchMessageSent != null)
- {
- await BatchMessageSent(new ResultMessage(SR.QueryServiceCompletedSuccessfully, false, null));
- }
- if (QueryCompleted != null)
- {
- await QueryCompleted(this);
- }
- return;
- }
-
- // Locate and setup the connection
- DbConnection queryConnection = await ConnectionService.Instance.GetOrOpenConnection(editorConnection.OwnerUri, ConnectionType.Query);
- ReliableSqlConnection sqlConn = queryConnection as ReliableSqlConnection;
- if (sqlConn != null)
- {
- // Subscribe to database informational messages
- sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage;
- }
-
+ ReliableSqlConnection sqlConn = null;
try
{
+ // Mark that we've internally executed
+ hasExecuteBeenCalled = true;
+
+ // Don't actually execute if there aren't any batches to execute
+ if (Batches.Length == 0)
+ {
+ if (BatchMessageSent != null)
+ {
+ await BatchMessageSent(new ResultMessage(SR.QueryServiceCompletedSuccessfully, false, null));
+ }
+ if (QueryCompleted != null)
+ {
+ await QueryCompleted(this);
+ }
+ return;
+ }
+
+ // Locate and setup the connection
+ DbConnection queryConnection = await ConnectionService.Instance.GetOrOpenConnection(editorConnection.OwnerUri, ConnectionType.Query);
+ sqlConn = queryConnection as ReliableSqlConnection;
+ if (sqlConn != null)
+ {
+ // Subscribe to database informational messages
+ sqlConn.GetUnderlyingConnection().InfoMessage += OnInfoMessage;
+ }
+
+
// Execute beforeBatches synchronously, before the user defined batches
foreach (Batch b in BeforeBatches)
{
@@ -393,7 +406,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
if (QueryCompleted != null)
{
await QueryCompleted(this);
- }
+ }
}
catch (Exception e)
{
@@ -405,16 +418,18 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
finally
{
+ // Remove the message handler from the connection
if (sqlConn != null)
{
// Subscribe to database informational messages
sqlConn.GetUnderlyingConnection().InfoMessage -= OnInfoMessage;
}
- }
-
- if (newDatabaseName != null)
- {
- ConnectionService.Instance.ChangeConnectionDatabaseContext(editorConnection.OwnerUri, newDatabaseName);
+
+ // If any message notified us we had changed databases, then we must let the connection service know
+ if (newDatabaseName != null)
+ {
+ ConnectionService.Instance.ChangeConnectionDatabaseContext(editorConnection.OwnerUri, newDatabaseName);
+ }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
index 3263d9ce..df093d33 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
@@ -504,9 +504,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
}
});
+
+ // Add exception handling to the save task
+ Task taskWithHandling = saveAsTask.ContinueWithOnFaulted(t =>
+ {
+ failureHandler?.Invoke(saveParams, t.Exception.Message).Wait();
+ });
// If saving the task fails, return a failure
- if (!SaveTasks.TryAdd(saveParams.FilePath, saveAsTask))
+ if (!SaveTasks.TryAdd(saveParams.FilePath, taskWithHandling))
{
throw new InvalidOperationException(SR.QueryServiceSaveAsMiscStartingError);
}
@@ -537,7 +543,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
if (!SaveTasks.IsEmpty)
{
// Wait for tasks to finish before disposing ResultSet
- Task.WhenAll(SaveTasks.Values.ToArray()).ContinueWith((antecedent) =>
+ Task.WhenAll(SaveTasks.Values.ToArray()).ContinueWith(antecedent =>
{
if (disposing)
{
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/TaskExtensions.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/TaskExtensions.cs
new file mode 100644
index 00000000..818e6d5c
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/TaskExtensions.cs
@@ -0,0 +1,50 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.Utility
+{
+ public static class TaskExtensions
+ {
+ ///
+ /// Adds handling to check the Exception field of a task and log it if the task faulted
+ ///
+ ///
+ /// This will effectively swallow exceptions in the task chain.
+ ///
+ /// The task to continue
+ ///
+ /// An optional operation to perform after exception handling has occurred
+ ///
+ /// Task with exception handling on continuation
+ public static Task ContinueWithOnFaulted(this Task task, Action continuationAction)
+ {
+ return task.ContinueWith(t =>
+ {
+ // If the task hasn't faulted or has an exception, skip processing
+ if (!t.IsFaulted || t.Exception == null)
+ {
+ return;
+ }
+
+ // Construct an error message for an aggregate exception and log it
+ StringBuilder sb = new StringBuilder("Unhandled exception(s) in async task:");
+ foreach (Exception e in task.Exception.InnerExceptions)
+ {
+ sb.AppendLine($"{e.GetType().Name}: {e.Message}");
+ sb.AppendLine(e.StackTrace);
+ }
+ Logger.Write(LogLevel.Error, sb.ToString());
+
+ // Run the continuation task that was provided
+ continuationAction?.Invoke(t);
+ });
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj
index e74ed27a..012db001 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Microsoft.SqlTools.ServiceLayer.UnitTests.csproj
@@ -2,8 +2,8 @@
Exe
netcoreapp2.0
- false
- $(DefineConstants);NETCOREAPP1_0
+ false
+ $(DefineConstants);NETCOREAPP1_0
false
@@ -12,10 +12,10 @@
-
-
+
+
-
+
../../bin/ref/Newtonsoft.Json.dll
@@ -39,4 +39,4 @@
-
+
\ No newline at end of file
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TaskExtensionTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TaskExtensionTests.cs
new file mode 100644
index 00000000..b287f197
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TaskExtensionTests.cs
@@ -0,0 +1,53 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.SqlTools.ServiceLayer.Utility;
+using Xunit;
+
+namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
+{
+ public class TaskExtensionTests
+ {
+ [Fact]
+ public async Task ContinueWithOnFaultedNullContinuation()
+ {
+ // Setup: Create a task that will definitely fault
+ Task failureTask = new Task(() => throw new Exception("It fail!"));
+
+ // If: I continue on fault and start the task
+ Task continuationTask = failureTask.ContinueWithOnFaulted(null);
+ failureTask.Start();
+ await continuationTask;
+
+ // Then: The task should have completed without fault
+ Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status);
+ }
+
+ [Fact]
+ public async Task ContinueWithOnFaultedContinuatation()
+ {
+ // Setup:
+ // ... Create a new task that will definitely fault
+ Task failureTask = new Task(() => throw new Exception("It fail!"));
+
+ // ... Create a quick continuation task that will signify if it's been called
+ Task providedTask = null;
+
+ // If: I continue on fault, with a continuation task
+ Task continuationTask = failureTask.ContinueWithOnFaulted(task => { providedTask = task; });
+ failureTask.Start();
+ await continuationTask;
+
+ // Then:
+ // ... The task should have completed without fault
+ Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status);
+
+ // ... The continuation action should have been called with the original failure task
+ Assert.Equal(failureTask, providedTask);
+ }
+ }
+}
\ No newline at end of file