diff --git a/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs b/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs index 56f133a6..7700223a 100644 --- a/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs +++ b/src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs @@ -20,5 +20,10 @@ namespace Microsoft.SqlTools.Hosting.Contracts /// Error message /// public string Message { get; set; } + + public override string ToString() + { + return $"Error(Code={Code},Message='{Message}')"; + } } } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs index e738e932..7bfd289c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SerializationService.cs @@ -14,6 +14,7 @@ using Microsoft.SqlTools.Hosting; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.Utility; @@ -40,93 +41,111 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Begin to process request to save a resultSet to a file in CSV format /// - internal async Task HandleSerializeStartRequest(SerializeDataStartRequestParams serializeParams, + internal Task HandleSerializeStartRequest(SerializeDataStartRequestParams serializeParams, RequestContext requestContext) + { + // Run in separate thread so that message thread isn't held up by a potentially time consuming file write + Task.Run(async () => { + await RunSerializeStartRequest(serializeParams, requestContext); + }).ContinueWithOnFaulted(async t => await SendErrorAndCleanup(serializeParams?.FilePath, requestContext, t.Exception)); + return Task.CompletedTask; + } + + internal async Task RunSerializeStartRequest(SerializeDataStartRequestParams serializeParams, RequestContext requestContext) { try { + // Verify we have sensible inputs and there isn't a task running for this file already Validate.IsNotNull(nameof(serializeParams), serializeParams); Validate.IsNotNullOrWhitespaceString("FilePath", serializeParams.FilePath); DataSerializer serializer = null; - bool hasSerializer = inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer); - if (hasSerializer) + if (inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer)) { + // Cannot proceed as there is an in progress serialization happening throw new Exception(SR.SerializationServiceRequestInProgress(serializeParams.FilePath)); } - + + // Create a new serializer, save for future calls if needed, and write the request out serializer = new DataSerializer(serializeParams); if (!serializeParams.IsLastBatch) { inProgressSerializations.AddOrUpdate(serializer.FilePath, serializer, (key, old) => serializer); } - Func> writeData = () => - { - return Task.Factory.StartNew(() => - { - var result = serializer.ProcessRequest(serializeParams); - return result; - }); - }; - await HandleRequest(writeData, requestContext, "HandleSerializeStartRequest"); + + Logger.Write(TraceEventType.Verbose, "HandleSerializeStartRequest"); + SerializeDataResult result = serializer.ProcessRequest(serializeParams); + await requestContext.SendResult(result); } catch (Exception ex) { - await requestContext.SendError(ex.Message); + await SendErrorAndCleanup(serializeParams.FilePath, requestContext, ex); } } + private async Task SendErrorAndCleanup(string filePath, RequestContext requestContext, Exception ex) + { + if (filePath != null) + { + try + { + DataSerializer removed; + inProgressSerializations.TryRemove(filePath, out removed); + if (removed != null) + { + // Flush any contents to disk and remove the writer + removed.CloseStreams(); + } + } + catch + { + // Do not care if there was an error removing this, must always delete if something failed + } + } + await requestContext.SendError(ex.Message); + } + /// /// Process request to save a resultSet to a file in CSV format /// - internal async Task HandleSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams, + internal Task HandleSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams, RequestContext requestContext) + { + // Run in separate thread so that message thread isn't held up by a potentially time consuming file write + Task.Run(async () => + { + await RunSerializeContinueRequest(serializeParams, requestContext); + }).ContinueWithOnFaulted(async t => await SendErrorAndCleanup(serializeParams?.FilePath, requestContext, t.Exception)); + return Task.CompletedTask; + } + + internal async Task RunSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams, RequestContext requestContext) { try { + // Verify we have sensible inputs and some data has already been sent for the file Validate.IsNotNull(nameof(serializeParams), serializeParams); Validate.IsNotNullOrWhitespaceString("FilePath", serializeParams.FilePath); DataSerializer serializer = null; - bool hasSerializer = inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer); - if (!hasSerializer) + if (!inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer)) { throw new Exception(SR.SerializationServiceRequestNotFound(serializeParams.FilePath)); } - - Func> writeData = () => + + // Write to file and cleanup if needed + Logger.Write(TraceEventType.Verbose, "HandleSerializeContinueRequest"); + SerializeDataResult result = serializer.ProcessRequest(serializeParams); + if (serializeParams.IsLastBatch) { - return Task.Factory.StartNew(() => - { - var result = serializer.ProcessRequest(serializeParams); - if (serializeParams.IsLastBatch) - { - // Cleanup the serializer - this.inProgressSerializations.TryRemove(serializer.FilePath, out serializer); - } - return result; - }); - }; - await HandleRequest(writeData, requestContext, "HandleSerializeContinueRequest"); - } - catch (Exception ex) - { - await requestContext.SendError(ex.Message); - } - } - - private async Task HandleRequest(Func> handler, RequestContext requestContext, string requestType) - { - Logger.Write(TraceEventType.Verbose, requestType); - - try - { - T result = await handler(); + // Cleanup the serializer + this.inProgressSerializations.TryRemove(serializer.FilePath, out serializer); + } await requestContext.SendResult(result); } catch (Exception ex) { - await requestContext.SendError(ex.Message); + await SendErrorAndCleanup(serializeParams.FilePath, requestContext, ex); } } } @@ -242,9 +261,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution this.writer = factory.GetWriter(requestParams.FilePath); } } - private void CloseStreams() + public void CloseStreams() { - this.writer.Dispose(); + if (this.writer != null) + { + this.writer.Dispose(); + this.writer = null; + } } private SaveResultsAsJsonRequestParams CreateJsonRequestParams() diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs index 306a5b4e..d9487f0f 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test.Common/RequestContextMocking/EventFlowValidator.cs @@ -153,7 +153,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Common.RequestContextMocking ReceivedEvent received = ReceivedEvents[i]; // Step 1) Make sure the event type matches - Assert.Equal(expected.EventType, received.EventType); + Assert.True(expected.EventType.Equals(received.EventType), + string.Format("Expected EventType {0} but got {1}. Received object is {2}", expected.EventType, received.EventType, received.EventObject.ToString())); // Step 2) Make sure the param type matches Assert.True( expected.ParamType == received.EventObject.GetType() diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs index 30d83700..b156d2f2 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/SaveResults/SerializationServiceTests.cs @@ -92,7 +92,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults .AddStandardResultValidator() .Complete(); - await SerializationService.HandleSerializeStartRequest(saveParams, efv.Object); + await SerializationService.RunSerializeStartRequest(saveParams, efv.Object); // Then: // ... There should not have been an error @@ -189,8 +189,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults .AddStandardResultValidator() .Complete(); - await SerializationService.HandleSerializeStartRequest(request1, efv.Object); - + await SerializationService.RunSerializeStartRequest(request1, efv.Object); + // Then: // ... There should not have been an error efv.Validate(); @@ -202,7 +202,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults .AddStandardResultValidator() .Complete(); - await SerializationService.HandleSerializeContinueRequest(request1, efv.Object); + await SerializationService.RunSerializeContinueRequest(request1, efv.Object); // Then: // ... There should not have been an error @@ -260,10 +260,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.SaveResults private static void AssertLineEquals(string line, string[] expected) { var actual = line.Split(','); - Assert.True(actual.Length == expected.Length, string.Format("Line '{0}' does not match values {1}", line, string.Join(",", expected))); + Assert.True(actual.Length == expected.Length, $"Line '{line}' does not match values {string.Join(",", expected)}"); for (int i = 0; i < actual.Length; i++) { - Assert.True(expected[i] == actual[i], string.Format("Line '{0}' does not match values '{1}' as '{2}' does not equal '{3}'", line, string.Join(",", expected), expected[i], actual[i])); + Assert.True(expected[i] == actual[i], $"Line '{line}' does not match values '{string.Join(",", expected)}' as '{expected[i]}' does not equal '{actual[i]}'"); } }