diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs new file mode 100644 index 00000000..78c2c72f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs @@ -0,0 +1,70 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the save results request + /// + public class SaveResultsRequestParams + { + /// + /// The path of the file to save results in + /// + public string FilePath { get; set; } + + /// + /// The encoding of the file to save results in + /// + public string FileEncoding { get; set; } + + /// + /// Include headers of columns in CSV + /// + public bool IncludeHeaders { get; set; } + + /// + /// Index of the batch to get the results from + /// + public int BatchIndex { get; set; } + + /// + /// Index of the result set to get the results from + /// + public int ResultSetIndex { get; set; } + + /// + /// CSV - Write values in quotes + /// + public Boolean ValueInQuotes { get; set; } + + /// + /// URI for the editor that called save results + /// + public string OwnerUri { get; set; } + } + + /// + /// Parameters for the save results result + /// + public class SaveResultRequestResult + { + /// + /// Error messages for saving to file. + /// + public string Messages { get; set; } + } + + public class SaveResultsAsCsvRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/save"); + } + +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 97d89fc9..882cf11b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -2,9 +2,10 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. // - using System; using System.Collections.Concurrent; +using System.IO; +using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Hosting; @@ -98,6 +99,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest); serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest); serviceHost.SetRequestHandler(QueryCancelRequest.Type, HandleCancelRequest); + serviceHost.SetRequestHandler(SaveResultsAsCsvRequest.Type, HandleSaveResultsAsCsvRequest); // Register handler for shutdown event serviceHost.RegisterShutdownTask((shutdownParams, requestContext) => @@ -256,6 +258,58 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } + /// + /// Process request to save a resultSet to a file in CSV format + /// + public async Task HandleSaveResultsAsCsvRequest( SaveResultsRequestParams saveParams, + RequestContext requestContext) + { + // retrieve query for OwnerUri + Query result; + if (!ActiveQueries.TryGetValue(saveParams.OwnerUri, out result)) + { + await requestContext.SendResult(new SaveResultRequestResult + { + Messages = "Failed to save results, ID not found." + }); + return; + } + try + { + using (StreamWriter csvFile = new StreamWriter(File.OpenWrite(saveParams.FilePath))) + { + // get the requested resultSet from query + Batch selectedBatch = result.Batches[saveParams.BatchIndex]; + ResultSet selectedResultSet = (selectedBatch.ResultSets.ToList())[saveParams.ResultSetIndex]; + if ( saveParams.IncludeHeaders) + { + // write column names to csv + await csvFile.WriteLineAsync( string.Join( ",", selectedResultSet.Columns.Select( column => SaveResults.EncodeCsvField(column.ColumnName) ?? string.Empty))); + } + + // write rows to csv + foreach( var row in selectedResultSet.Rows) + { + await csvFile.WriteLineAsync(string.Join( ",", row.Select( field => SaveResults.EncodeCsvField( (field != null) ? field.ToString(): string.Empty)))); + } + } + } + catch(Exception ex) + { + // Delete file when exception occurs + if(File.Exists(saveParams.FilePath)) + { + File.Delete(saveParams.FilePath); + } + await requestContext.SendError(ex.Message); + return; + } + await requestContext.SendResult(new SaveResultRequestResult + { + Messages = "Success" + }); + return; + } #endregion #region Private Helpers diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index 84e18c99..41978255 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -111,6 +111,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public long RowCount { get; private set; } + /// + /// The rows of this result set + /// + public IEnumerable Rows + { + get { return FileOffsets.Select(offset => fileStreamReader.ReadRow(offset, Columns)); } + } + #endregion #region Public Methods diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs new file mode 100644 index 00000000..f63f253f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs @@ -0,0 +1,84 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.Text; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + internal class SaveResults{ + + /// Method ported from SSMS + + /// + /// Encodes a single field for inserting into a CSV record. The following rules are applied: + /// + /// All double quotes (") are replaced with a pair of consecutive double quotes + /// + /// The entire field is also surrounded by a pair of double quotes if any of the following conditions are met: + /// + /// The field begins or ends with a space + /// The field begins or ends with a tab + /// The field contains the ListSeparator string + /// The field contains the '\n' character + /// The field contains the '\r' character + /// The field contains the '"' character + /// + /// + /// The field to encode + /// The CSV encoded version of the original field + internal static String EncodeCsvField(String field) + { + StringBuilder sbField = new StringBuilder(field); + + //Whether this field has special characters which require it to be embedded in quotes + bool embedInQuotes = false; + + //Check for leading/trailing spaces + if (sbField.Length > 0 && + (sbField[0] == ' ' || + sbField[0] == '\t' || + sbField[sbField.Length - 1] == ' ' || + sbField[sbField.Length - 1] == '\t')) + { + embedInQuotes = true; + } + else + { //List separator being in the field will require quotes + if (field.Contains(",")) + { + embedInQuotes = true; + } + else + { + for (int i = 0; i < sbField.Length; ++i) + { + //Check whether this character is a special character + if (sbField[i] == '\r' || + sbField[i] == '\n' || + sbField[i] == '"') + { //If even one character requires embedding the whole field will + //be embedded in quotes so we can just break out now + embedInQuotes = true; + break; + } + } + } + } + + //Replace all quotes in the original field with double quotes + sbField.Replace("\"", "\"\""); + + String ret = sbField.ToString(); + + if (embedInQuotes) + { + ret = "\"" + ret + "\""; + } + + return ret; + } + } + +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs new file mode 100644 index 00000000..61cf5fea --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Threading.Tasks; +using System.Runtime.InteropServices; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + /// + /// Tests for saving a result set to a file + /// + public class SaveResultsTests + { + /// + /// Test save results to a file as CSV with correct parameters + /// + [Fact] + public void SaveResultsAsCsvSuccessTest() + { + // Execute a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri }; + var executeRequest = GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // Request to save the results as csv with correct parameters + var saveParams = new SaveResultsRequestParams { OwnerUri = Common.OwnerUri, ResultSetIndex = 0, BatchIndex = 0 }; + saveParams.FilePath = "testwrite.csv"; + saveParams.IncludeHeaders = true; + SaveResultRequestResult result = null; + var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); + queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); + queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Expect to see a file successfully created in filepath and a success message + Assert.Equal("Success", result.Messages); + Assert.True(File.Exists(saveParams.FilePath)); + VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); + + // Delete temp file after test + if(File.Exists(saveParams.FilePath)) + { + File.Delete(saveParams.FilePath); + } + } + + /// + /// Test handling exception in saving results to file + /// + [Fact] + public void SaveResultsAsCsvExceptionTest() + { + // Execute a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri }; + var executeRequest = GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // Request to save the results as csv with incorrect filepath + var saveParams = new SaveResultsRequestParams { OwnerUri = Common.OwnerUri, ResultSetIndex = 0, BatchIndex = 0 }; + if ( RuntimeInformation.IsOSPlatform( OSPlatform.Windows)) + { + saveParams.FilePath = "G:\\test.csv"; + } + else + { + saveParams.FilePath = "/test.csv"; + } + // SaveResultRequestResult result = null; + String errMessage = null; + var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (String) err); + queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); + queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Expect to see error message + Assert.NotNull(errMessage); + VerifySaveResultsCallCount(saveRequest, Times.Never(), Times.Once()); + Assert.False(File.Exists(saveParams.FilePath)); + } + + /// + /// Test saving results to file when the requested result set is no longer active + /// + [Fact] + public void SaveResultsAsCsvQueryNotFoundTest() + { + // Execute a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri }; + var executeRequest = GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // Request to save the results as csv with query that is no longer active + var saveParams = new SaveResultsRequestParams { OwnerUri = "falseuri", ResultSetIndex = 0, BatchIndex = 0 }; + saveParams.FilePath = "testwrite.csv"; + SaveResultRequestResult result = null; + var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); + // queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); + queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Expect message that save failed + Assert.Equal("Failed to save results, ID not found.", result.Messages); + Assert.False(File.Exists(saveParams.FilePath)); + VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never()); + } + + #region Mocking + + /// + /// Mock the requestContext for saving a result set + /// + /// + /// + /// + private static Mock> GetSaveResultsContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny ())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + /// + /// Verify the call count for sendResult and error + /// + /// + /// + /// + private static void VerifySaveResultsCallCount(Mock> mock, + Times sendResultCalls, Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + /// + /// Mock request context for executing a query + /// + /// + /// + /// + /// + /// + public static Mock> GetQueryExecuteResultContextMock( + Action resultCallback, + Action, QueryExecuteCompleteParams> eventCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendEvent + var sendEventFlow = requestContext.Setup(rc => rc.SendEvent( + It.Is>(m => m == QueryExecuteCompleteEvent.Type), + It.IsAny())) + .Returns(Task.FromResult(0)); + if (eventCallback != null) + { + sendEventFlow.Callback(eventCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext.Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + + #endregion + + } +}