diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs
index 56f133a6..7700223a 100644
--- a/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs
+++ b/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs
@@ -20,5 +20,10 @@ namespace Microsoft.SqlTools.Hosting.Contracts
/// Error message
///
public string Message { get; set; }
+
+ public override string ToString()
+ {
+ return $"Error(Code={Code},Message='{Message}')";
+ }
}
}
\ No newline at end of file
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs
index e738e932..7bfd289c 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs
@@ -14,6 +14,7 @@ using Microsoft.SqlTools.Hosting;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
+using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;
@@ -40,93 +41,111 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
/// Begin to process request to save a resultSet to a file in CSV format
///
- internal async Task HandleSerializeStartRequest(SerializeDataStartRequestParams serializeParams,
+ internal Task HandleSerializeStartRequest(SerializeDataStartRequestParams serializeParams,
RequestContext requestContext)
+ {
+ // Run in separate thread so that message thread isn't held up by a potentially time consuming file write
+ Task.Run(async () => {
+ await RunSerializeStartRequest(serializeParams, requestContext);
+ }).ContinueWithOnFaulted(async t => await SendErrorAndCleanup(serializeParams?.FilePath, requestContext, t.Exception));
+ return Task.CompletedTask;
+ }
+
+ internal async Task RunSerializeStartRequest(SerializeDataStartRequestParams serializeParams, RequestContext requestContext)
{
try
{
+ // Verify we have sensible inputs and there isn't a task running for this file already
Validate.IsNotNull(nameof(serializeParams), serializeParams);
Validate.IsNotNullOrWhitespaceString("FilePath", serializeParams.FilePath);
DataSerializer serializer = null;
- bool hasSerializer = inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer);
- if (hasSerializer)
+ if (inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer))
{
+ // Cannot proceed as there is an in progress serialization happening
throw new Exception(SR.SerializationServiceRequestInProgress(serializeParams.FilePath));
}
-
+
+ // Create a new serializer, save for future calls if needed, and write the request out
serializer = new DataSerializer(serializeParams);
if (!serializeParams.IsLastBatch)
{
inProgressSerializations.AddOrUpdate(serializer.FilePath, serializer, (key, old) => serializer);
}
- Func> writeData = () =>
- {
- return Task.Factory.StartNew(() =>
- {
- var result = serializer.ProcessRequest(serializeParams);
- return result;
- });
- };
- await HandleRequest(writeData, requestContext, "HandleSerializeStartRequest");
+
+ Logger.Write(TraceEventType.Verbose, "HandleSerializeStartRequest");
+ SerializeDataResult result = serializer.ProcessRequest(serializeParams);
+ await requestContext.SendResult(result);
}
catch (Exception ex)
{
- await requestContext.SendError(ex.Message);
+ await SendErrorAndCleanup(serializeParams.FilePath, requestContext, ex);
}
}
+ private async Task SendErrorAndCleanup(string filePath, RequestContext requestContext, Exception ex)
+ {
+ if (filePath != null)
+ {
+ try
+ {
+ DataSerializer removed;
+ inProgressSerializations.TryRemove(filePath, out removed);
+ if (removed != null)
+ {
+ // Flush any contents to disk and remove the writer
+ removed.CloseStreams();
+ }
+ }
+ catch
+ {
+ // Do not care if there was an error removing this, must always delete if something failed
+ }
+ }
+ await requestContext.SendError(ex.Message);
+ }
+
///
/// Process request to save a resultSet to a file in CSV format
///
- internal async Task HandleSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams,
+ internal Task HandleSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams,
RequestContext requestContext)
+ {
+ // Run in separate thread so that message thread isn't held up by a potentially time consuming file write
+ Task.Run(async () =>
+ {
+ await RunSerializeContinueRequest(serializeParams, requestContext);
+ }).ContinueWithOnFaulted(async t => await SendErrorAndCleanup(serializeParams?.FilePath, requestContext, t.Exception));
+ return Task.CompletedTask;
+ }
+
+ internal async Task RunSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams, RequestContext requestContext)
{
try
{
+ // Verify we have sensible inputs and some data has already been sent for the file
Validate.IsNotNull(nameof(serializeParams), serializeParams);
Validate.IsNotNullOrWhitespaceString("FilePath", serializeParams.FilePath);
DataSerializer serializer = null;
- bool hasSerializer = inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer);
- if (!hasSerializer)
+ if (!inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer))
{
throw new Exception(SR.SerializationServiceRequestNotFound(serializeParams.FilePath));
}
-
- Func> writeData = () =>
+
+ // Write to file and cleanup if needed
+ Logger.Write(TraceEventType.Verbose, "HandleSerializeContinueRequest");
+ SerializeDataResult result = serializer.ProcessRequest(serializeParams);
+ if (serializeParams.IsLastBatch)
{
- return Task.Factory.StartNew(() =>
- {
- var result = serializer.ProcessRequest(serializeParams);
- if (serializeParams.IsLastBatch)
- {
- // Cleanup the serializer
- this.inProgressSerializations.TryRemove(serializer.FilePath, out serializer);
- }
- return result;
- });
- };
- await HandleRequest(writeData, requestContext, "HandleSerializeContinueRequest");
- }
- catch (Exception ex)
- {
- await requestContext.SendError(ex.Message);
- }
- }
-
- private async Task HandleRequest(Func> handler, RequestContext requestContext, string requestType)
- {
- Logger.Write(TraceEventType.Verbose, requestType);
-
- try
- {
- T result = await handler();
+ // Cleanup the serializer
+ this.inProgressSerializations.TryRemove(serializer.FilePath, out serializer);
+ }
await requestContext.SendResult(result);
}
catch (Exception ex)
{
- await requestContext.SendError(ex.Message);
+ await SendErrorAndCleanup(serializeParams.FilePath, requestContext, ex);
}
}
}
@@ -242,9 +261,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
this.writer = factory.GetWriter(requestParams.FilePath);
}
}
- private void CloseStreams()
+ public void CloseStreams()
{
- this.writer.Dispose();
+ if (this.writer != null)
+ {
+ this.writer.Dispose();
+ this.writer = null;
+ }
}
private SaveResultsAsJsonRequestParams CreateJsonRequestParams()
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs
index 306a5b4e..d9487f0f 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs
@@ -153,7 +153,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common.RequestContextMocking
ReceivedEvent received = ReceivedEvents[i];
// Step 1) Make sure the event type matches
- Assert.Equal(expected.EventType, received.EventType);
+ Assert.True(expected.EventType.Equals(received.EventType),
+ string.Format("Expected EventType {0} but got {1}. Received object is {2}", expected.EventType, received.EventType, received.EventObject.ToString()));
// Step 2) Make sure the param type matches
Assert.True( expected.ParamType == received.EventObject.GetType()
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs
index 30d83700..b156d2f2 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs
@@ -92,7 +92,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
.AddStandardResultValidator()
.Complete();
- await SerializationService.HandleSerializeStartRequest(saveParams, efv.Object);
+ await SerializationService.RunSerializeStartRequest(saveParams, efv.Object);
// Then:
// ... There should not have been an error
@@ -189,8 +189,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
.AddStandardResultValidator()
.Complete();
- await SerializationService.HandleSerializeStartRequest(request1, efv.Object);
-
+ await SerializationService.RunSerializeStartRequest(request1, efv.Object);
+
// Then:
// ... There should not have been an error
efv.Validate();
@@ -202,7 +202,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
.AddStandardResultValidator()
.Complete();
- await SerializationService.HandleSerializeContinueRequest(request1, efv.Object);
+ await SerializationService.RunSerializeContinueRequest(request1, efv.Object);
// Then:
// ... There should not have been an error
@@ -260,10 +260,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults
private static void AssertLineEquals(string line, string[] expected)
{
var actual = line.Split(',');
- Assert.True(actual.Length == expected.Length, string.Format("Line '{0}' does not match values {1}", line, string.Join(",", expected)));
+ Assert.True(actual.Length == expected.Length, $"Line '{line}' does not match values {string.Join(",", expected)}");
for (int i = 0; i < actual.Length; i++)
{
- Assert.True(expected[i] == actual[i], string.Format("Line '{0}' does not match values '{1}' as '{2}' does not equal '{3}'", line, string.Join(",", expected), expected[i], actual[i]));
+ Assert.True(expected[i] == actual[i], $"Line '{line}' does not match values '{string.Join(",", expected)}' as '{expected[i]}' does not equal '{actual[i]}'");
}
}