From 2a688cb87f4a88d8bbbd69e5fe1469b7137e00af Mon Sep 17 00:00:00 2001 From: Sharon Ravindran Date: Fri, 21 Oct 2016 20:07:21 -0700 Subject: [PATCH] Make save result async (#107) * Make save results asynchronous * Prevent write share of file * Lock objects in stages * Create Save result objects * refactor and write rows in batches * CHange batchSize from test value * Remove await in handler * Removing the file reader as a member of the resultset * Change Dispose to wait for save * Change concurrentBag * PascalCase variables * Modify function signature and tests * Safe file methods * refactor ResultSets to Ilist and remove ToList * Change dictionary key and prevent add to saveTasks during dispose * Simplify row concatenation * Fix prevent add * Fix prevent add * Add methods to expose saveTasks and isBeingDisposed --- .../QueryExecution/Batch.cs | 2 +- .../Contracts/SaveResultsRequest.cs | 11 - .../QueryExecution/FileUtils.cs | 46 ++++ .../QueryExecution/QueryExecutionService.cs | 178 ++++--------- .../QueryExecution/ResultSet.cs | 133 ++++++---- .../QueryExecution/SaveResults.cs | 241 +++++++++++++++++- .../QueryExecution/SaveResultsTests.cs | 53 +++- 7 files changed, 468 insertions(+), 196 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs index 7d96b0b6..17901e6a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -125,7 +125,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// The result sets of the batch execution /// - public IEnumerable ResultSets + public IList ResultSets { get { return resultSets; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs index 1cf2390e..369fdb67 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/SaveResultsRequest.cs @@ -60,17 +60,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// Parameters to save results as CSV /// public class SaveResultsAsCsvRequestParams: SaveResultsRequestParams{ - - /// - /// CSV - Write values in quotes - /// - public Boolean ValueInQuotes { get; set; } - - /// - /// The encoding of the file to save results in - /// - public string FileEncoding { get; set; } - /// /// Include headers of columns in CSV /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs new file mode 100644 index 00000000..d795c2c2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/FileUtils.cs @@ -0,0 +1,46 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.IO; +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + internal static class FileUtils + { + /// + /// Checks if file exists and swallows exceptions, if any + /// + /// path of the file + /// + internal static bool SafeFileExists(string path) + { + try + { + return File.Exists(path); + } + catch (Exception) + { + // Swallow exception + return false; + } + } + + /// + /// Deletes a file and swallows exceptions, if any + /// + /// + internal static void SafeFileDelete(string path) + { + try + { + File.Delete(path); + } + catch (Exception) + { + // Swallow exception, do nothing + } + } + + } +} \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 39543c33..9819504d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -4,7 +4,6 @@ // using System; using System.Collections.Concurrent; -using System.IO; using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Connection; @@ -16,7 +15,6 @@ using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; -using Newtonsoft.Json; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { @@ -252,7 +250,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// Process request to save a resultSet to a file in CSV format /// - public async Task HandleSaveResultsAsCsvRequest(SaveResultsAsCsvRequestParams saveParams, + internal async Task HandleSaveResultsAsCsvRequest(SaveResultsAsCsvRequestParams saveParams, RequestContext requestContext) { // retrieve query for OwnerUri @@ -265,67 +263,39 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }); return; } - try + + + ResultSet selectedResultSet = result.Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + if (!selectedResultSet.IsBeingDisposed) { - using (StreamWriter csvFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create))) + // Create SaveResults object and add success and error handlers to respective events + SaveResults saveAsCsv = new SaveResults(); + + SaveResults.AsyncSaveEventHandler successHandler = async message => { - // get the requested resultSet from query - Batch selectedBatch = result.Batches[saveParams.BatchIndex]; - ResultSet selectedResultSet = (selectedBatch.ResultSets.ToList())[saveParams.ResultSetIndex]; - int columnCount = 0; - int rowCount = 0; - int columnStartIndex = 0; - int rowStartIndex = 0; - - // set column, row counts depending on whether save request is for entire result set or a subset - if (SaveResults.isSaveSelection(saveParams)) - { - columnCount = saveParams.ColumnEndIndex.Value - saveParams.ColumnStartIndex.Value + 1; - rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; - columnStartIndex = saveParams.ColumnStartIndex.Value; - rowStartIndex =saveParams.RowStartIndex.Value; - } - else - { - columnCount = selectedResultSet.Columns.Length; - rowCount = (int)selectedResultSet.RowCount; - } - - // write column names if include headers option is chosen - if (saveParams.IncludeHeaders) - { - await csvFile.WriteLineAsync( string.Join( ",", selectedResultSet.Columns.Skip(columnStartIndex).Take(columnCount).Select( column => - SaveResults.EncodeCsvField(column.ColumnName) ?? string.Empty))); - } - - // retrieve rows and write as csv - ResultSetSubset resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex, rowCount); - foreach (var row in resultSubset.Rows) - { - await csvFile.WriteLineAsync( string.Join( ",", row.Skip(columnStartIndex).Take(columnCount).Select( field => - SaveResults.EncodeCsvField((field != null) ? field.ToString(): "NULL")))); - } - - } - - // Successfully wrote file, send success result - await requestContext.SendResult(new SaveResultRequestResult { Messages = null }); - } - catch(Exception ex) - { - // Delete file when exception occurs - if (File.Exists(saveParams.FilePath)) + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendResult(new SaveResultRequestResult { Messages = message }); + }; + saveAsCsv.SaveCompleted += successHandler; + SaveResults.AsyncSaveEventHandler errorHandler = async message => { - File.Delete(saveParams.FilePath); - } - await requestContext.SendError(ex.Message); + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendError(message); + }; + saveAsCsv.SaveFailed += errorHandler; + + saveAsCsv.SaveResultSetAsCsv(saveParams, requestContext, result); + + // Associate the ResultSet with the save task + selectedResultSet.AddSaveTask(saveParams.FilePath, saveAsCsv.SaveTask); + } } /// /// Process request to save a resultSet to a file in JSON format /// - public async Task HandleSaveResultsAsJsonRequest(SaveResultsAsJsonRequestParams saveParams, + internal async Task HandleSaveResultsAsJsonRequest(SaveResultsAsJsonRequestParams saveParams, RequestContext requestContext) { // retrieve query for OwnerUri @@ -338,73 +308,31 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution }); return; } - try + + ResultSet selectedResultSet = result.Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + if (!selectedResultSet.IsBeingDisposed) { - using (StreamWriter jsonFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create))) - using (JsonWriter jsonWriter = new JsonTextWriter(jsonFile) ) + // Create SaveResults object and add success and error handlers to respective events + SaveResults saveAsJson = new SaveResults(); + SaveResults.AsyncSaveEventHandler successHandler = async message => { - jsonWriter.Formatting = Formatting.Indented; - jsonWriter.WriteStartArray(); - - // get the requested resultSet from query - Batch selectedBatch = result.Batches[saveParams.BatchIndex]; - ResultSet selectedResultSet = selectedBatch.ResultSets.ToList()[saveParams.ResultSetIndex]; - int rowCount = 0; - int rowStartIndex = 0; - int columnStartIndex = 0; - int columnEndIndex = 0; - - // set column, row counts depending on whether save request is for entire result set or a subset - if (SaveResults.isSaveSelection(saveParams)) - { - - rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; - rowStartIndex = saveParams.RowStartIndex.Value; - columnStartIndex = saveParams.ColumnStartIndex.Value; - columnEndIndex = saveParams.ColumnEndIndex.Value + 1 ; // include the last column - } - else - { - rowCount = (int)selectedResultSet.RowCount; - columnEndIndex = selectedResultSet.Columns.Length; - } - - // retrieve rows and write as json - ResultSetSubset resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex, rowCount); - foreach (var row in resultSubset.Rows) - { - jsonWriter.WriteStartObject(); - for (int i = columnStartIndex ; i < columnEndIndex; i++) - { - //get column name - DbColumnWrapper col = selectedResultSet.Columns[i]; - string val = row[i]; - jsonWriter.WritePropertyName(col.ColumnName); - if (val == null) - { - jsonWriter.WriteNull(); - } - else - { - jsonWriter.WriteValue(val); - } - } - jsonWriter.WriteEndObject(); - } - jsonWriter.WriteEndArray(); - } - - await requestContext.SendResult(new SaveResultRequestResult { Messages = null }); - } - catch(Exception ex) - { - // Delete file when exception occurs - if (File.Exists(saveParams.FilePath)) + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendResult(new SaveResultRequestResult { Messages = message }); + }; + saveAsJson.SaveCompleted += successHandler; + SaveResults.AsyncSaveEventHandler errorHandler = async message => { - File.Delete(saveParams.FilePath); - } - await requestContext.SendError(ex.Message); + selectedResultSet.RemoveSaveTask(saveParams.FilePath); + await requestContext.SendError(message); + }; + saveAsJson.SaveFailed += errorHandler; + + saveAsJson.SaveResultSetAsJson(saveParams, requestContext, result); + + // Associate the ResultSet with the save task + selectedResultSet.AddSaveTask(saveParams.FilePath, saveAsJson.SaveTask); } + } #endregion @@ -439,27 +367,27 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution string queryText; - if (executeParams.QuerySelection != null) + if (executeParams.QuerySelection != null) { string[] queryTextArray = queryFile.GetLinesInRange( new BufferRange( new BufferPosition( - executeParams.QuerySelection.StartLine + 1, + executeParams.QuerySelection.StartLine + 1, executeParams.QuerySelection.StartColumn + 1 - ), + ), new BufferPosition( - executeParams.QuerySelection.EndLine + 1, + executeParams.QuerySelection.EndLine + 1, executeParams.QuerySelection.EndColumn + 1 ) ) ); queryText = queryTextArray.Aggregate((a, b) => a + '\r' + '\n' + b); - } - else + } + else { queryText = queryFile.Contents; } - + // If we can't add the query now, it's assumed the query is in progress Query newQuery = new Query(queryText, connectionInfo, settings, BufferFileFactory); if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index cc6a98ab..a96de759 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -4,6 +4,7 @@ // using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Data.Common; using System.Linq; @@ -44,12 +45,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// private readonly IFileStreamFactory fileStreamFactory; - /// - /// File stream reader that will be reused to make rapid-fire retrieval of result subsets - /// quick and low perf impact. - /// - private IFileStreamReader fileStreamReader; - /// /// Whether or not the result set has been read in from the database /// @@ -65,6 +60,16 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// private readonly string outputFileName; + /// + /// Whether the resultSet is in the process of being disposed + /// + private bool isBeingDisposed; + + /// + /// All save tasks currently saving this ResultSet + /// + private ConcurrentDictionary saveTasks; + #endregion /// @@ -86,10 +91,23 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // Store the factory fileStreamFactory = factory; hasBeenRead = false; + saveTasks = new ConcurrentDictionary(); } #region Properties + /// + /// Whether the resultSet is in the process of being disposed + /// + /// + internal bool IsBeingDisposed + { + get + { + return isBeingDisposed; + } + } + /// /// The columns for this result set /// @@ -120,18 +138,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public long RowCount { get; private set; } - /// - /// The rows of this result set - /// - public IEnumerable Rows - { - get - { - return FileOffsets.Select( - offset => fileStreamReader.ReadRow(offset, Columns).Select(cell => cell.DisplayValue).ToArray()); - } - } - #endregion #region Public Methods @@ -145,7 +151,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public Task GetSubset(int startRow, int rowCount) { // Sanity check to make sure that the results have been read beforehand - if (!hasBeenRead || fileStreamReader == null) + if (!hasBeenRead) { throw new InvalidOperationException(SR.QueryServiceResultSetNotRead); } @@ -164,28 +170,30 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { string[][] rows; - // If result set is 'for xml' or 'for json', - // Concatenate all the rows together into one row - if (isSingleColumnXmlJsonResultSet) + + using (IFileStreamReader fileStreamReader = fileStreamFactory.GetReader(outputFileName)) { - // Iterate over all the rows and process them into a list of string builders - IEnumerable sbRows = FileOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns) - .Select(cell => cell.DisplayValue).Aggregate(new StringBuilder(), (sb, value) => sb.Append(value))); - rows = new[] { new[] { string.Join(string.Empty, sbRows) } }; + // If result set is 'for xml' or 'for json', + // Concatenate all the rows together into one row + if (isSingleColumnXmlJsonResultSet) + { + // Iterate over all the rows and process them into a list of string builders + IEnumerable rowValues = FileOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)[0].DisplayValue); + rows = new[] { new[] { string.Join(string.Empty, rowValues) } }; + } + else + { + // Figure out which rows we need to read back + IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); + + // Iterate over the rows we need and process them into output + rows = rowOffsets.Select(rowOffset => + fileStreamReader.ReadRow(rowOffset, Columns).Select(cell => cell.DisplayValue).ToArray()) + .ToArray(); + + } } - else - { - // Figure out which rows we need to read back - IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); - - // Iterate over the rows we need and process them into output - rows = rowOffsets.Select(rowOffset => - fileStreamReader.ReadRow(rowOffset, Columns).Select(cell => cell.DisplayValue).ToArray()) - .ToArray(); - - } - // Retrieve the subset of the results as per the request return new ResultSetSubset { @@ -203,7 +211,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { // Mark that result has been read hasBeenRead = true; - fileStreamReader = fileStreamFactory.GetReader(outputFileName); // Open a writer for the file using (IFileStreamWriter fileWriter = fileStreamFactory.GetWriter(outputFileName, MaxCharsToStore, MaxXmlCharsToStore)) @@ -244,13 +251,31 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return; } - if (disposing) + isBeingDisposed = true; + // Check if saveTasks are running for this ResultSet + if (!saveTasks.IsEmpty) { - fileStreamReader?.Dispose(); - fileStreamFactory.DisposeFile(outputFileName); + // Wait for tasks to finish before disposing ResultSet + Task.WhenAll(saveTasks.Values.ToArray()).ContinueWith((antecedent) => + { + if (disposing) + { + fileStreamFactory.DisposeFile(outputFileName); + } + disposed = true; + isBeingDisposed = false; + }); + } + else + { + // If saveTasks is empty, continue with dispose + if (disposing) + { + fileStreamFactory.DisposeFile(outputFileName); + } + disposed = true; + isBeingDisposed = false; } - - disposed = true; } #endregion @@ -283,5 +308,25 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } #endregion + + #region Internal Methods to Add and Remove save tasks + internal void AddSaveTask(string key, Task saveTask) + { + saveTasks.TryAdd(key, saveTask); + } + + internal void RemoveSaveTask(string key) + { + Task completedTask; + saveTasks.TryRemove(key, out completedTask); + } + + internal Task GetSaveTask(string key) + { + Task completedTask; + saveTasks.TryRemove(key, out completedTask); + return completedTask; + } + #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs index 9188d5b0..6fd3a505 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/SaveResults.cs @@ -3,15 +3,47 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // using System; +using System.IO; +using System.Linq; using System.Text; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Newtonsoft.Json; + namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { - internal class SaveResults{ + internal class SaveResults + { + /// + /// Number of rows being read from the ResultSubset in one read + /// + private const int BatchSize = 1000; + + /// + /// Save Task that asynchronously writes ResultSet to file + /// + internal Task SaveTask { get; set; } + + /// + /// Event Handler for save events + /// + /// Message to be returned to client + /// + internal delegate Task AsyncSaveEventHandler(string message); + + /// + /// A successful save event + /// + internal event AsyncSaveEventHandler SaveCompleted; + + /// + /// A failed save event + /// + internal event AsyncSaveEventHandler SaveFailed; /// Method ported from SSMS - /// /// Encodes a single field for inserting into a CSV record. The following rules are applied: /// @@ -32,7 +64,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution internal static String EncodeCsvField(String field) { StringBuilder sbField = new StringBuilder(field); - + //Whether this field has special characters which require it to be embedded in quotes bool embedInQuotes = false; @@ -67,12 +99,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } } - + //Replace all quotes in the original field with double quotes sbField.Replace("\"", "\"\""); String ret = sbField.ToString(); - + if (embedInQuotes) { ret = "\"" + ret + "\""; @@ -81,11 +113,208 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return ret; } - internal static bool isSaveSelection(SaveResultsRequestParams saveParams) + /// + /// Check if request is a subset of result set or whole result set + /// + /// Parameters from the request + /// + internal static bool IsSaveSelection(SaveResultsRequestParams saveParams) { return (saveParams.ColumnStartIndex != null && saveParams.ColumnEndIndex != null && saveParams.RowEndIndex != null && saveParams.RowEndIndex != null); } + + /// + /// Save results as JSON format to the file specified in saveParams + /// + /// Parameters from the request + /// Request context for save results + /// Result query object + /// + internal void SaveResultSetAsJson(SaveResultsAsJsonRequestParams saveParams, RequestContext requestContext, Query result) + { + // Run in a separate thread + SaveTask = Task.Run(async () => + { + try + { + using (StreamWriter jsonFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create, FileAccess.ReadWrite, FileShare.Read))) + using (JsonWriter jsonWriter = new JsonTextWriter(jsonFile)) + { + + int rowCount = 0; + int rowStartIndex = 0; + int columnStartIndex = 0; + int columnEndIndex = 0; + + jsonWriter.Formatting = Formatting.Indented; + jsonWriter.WriteStartArray(); + + // Get the requested resultSet from query + Batch selectedBatch = result.Batches[saveParams.BatchIndex]; + ResultSet selectedResultSet = selectedBatch.ResultSets[saveParams.ResultSetIndex]; + + // Set column, row counts depending on whether save request is for entire result set or a subset + if (IsSaveSelection(saveParams)) + { + + rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; + rowStartIndex = saveParams.RowStartIndex.Value; + columnStartIndex = saveParams.ColumnStartIndex.Value; + columnEndIndex = saveParams.ColumnEndIndex.Value + 1; // include the last column + } + else + { + rowCount = (int)selectedResultSet.RowCount; + columnEndIndex = selectedResultSet.Columns.Length; + } + + // Split rows into batches + for (int count = 0; count < (rowCount / BatchSize) + 1; count++) + { + int numberOfRows = (count < rowCount / BatchSize) ? BatchSize : (rowCount % BatchSize); + if (numberOfRows == 0) + { + break; + } + + // Retrieve rows and write as json + ResultSetSubset resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex + count * BatchSize, numberOfRows); + foreach (var row in resultSubset.Rows) + { + jsonWriter.WriteStartObject(); + for (int i = columnStartIndex; i < columnEndIndex; i++) + { + // Write columnName, value pair + DbColumnWrapper col = selectedResultSet.Columns[i]; + string val = row[i]?.ToString(); + jsonWriter.WritePropertyName(col.ColumnName); + if (val == null) + { + jsonWriter.WriteNull(); + } + else + { + jsonWriter.WriteValue(val); + } + } + jsonWriter.WriteEndObject(); + } + + } + jsonWriter.WriteEndArray(); + } + + // Successfully wrote file, send success result + if (SaveCompleted != null) + { + await SaveCompleted(null); + } + + + } + catch (Exception ex) + { + // Delete file when exception occurs + if (FileUtils.SafeFileExists(saveParams.FilePath)) + { + FileUtils.SafeFileDelete(saveParams.FilePath); + } + if (SaveFailed != null) + { + await SaveFailed(ex.ToString()); + } + } + }); + } + + /// + /// Save results as CSV format to the file specified in saveParams + /// + /// Parameters from the request + /// Request context for save results + /// Result query object + /// + internal void SaveResultSetAsCsv(SaveResultsAsCsvRequestParams saveParams, RequestContext requestContext, Query result) + { + // Run in a separate thread + SaveTask = Task.Run(async () => + { + try + { + using (StreamWriter csvFile = new StreamWriter(File.Open(saveParams.FilePath, FileMode.Create, FileAccess.ReadWrite, FileShare.Read))) + { + ResultSetSubset resultSubset; + int columnCount = 0; + int rowCount = 0; + int columnStartIndex = 0; + int rowStartIndex = 0; + + // Get the requested resultSet from query + Batch selectedBatch = result.Batches[saveParams.BatchIndex]; + ResultSet selectedResultSet = (selectedBatch.ResultSets)[saveParams.ResultSetIndex]; + // Set column, row counts depending on whether save request is for entire result set or a subset + if (IsSaveSelection(saveParams)) + { + columnCount = saveParams.ColumnEndIndex.Value - saveParams.ColumnStartIndex.Value + 1; + rowCount = saveParams.RowEndIndex.Value - saveParams.RowStartIndex.Value + 1; + columnStartIndex = saveParams.ColumnStartIndex.Value; + rowStartIndex = saveParams.RowStartIndex.Value; + } + else + { + columnCount = selectedResultSet.Columns.Length; + rowCount = (int)selectedResultSet.RowCount; + } + + // Write column names if include headers option is chosen + if (saveParams.IncludeHeaders) + { + csvFile.WriteLine(string.Join(",", selectedResultSet.Columns.Skip(columnStartIndex).Take(columnCount).Select(column => + EncodeCsvField(column.ColumnName) ?? string.Empty))); + } + + for (int i = 0; i < (rowCount / BatchSize) + 1; i++) + { + int numberOfRows = (i < rowCount / BatchSize) ? BatchSize : (rowCount % BatchSize); + if (numberOfRows == 0) + { + break; + } + // Retrieve rows and write as csv + resultSubset = await result.GetSubset(saveParams.BatchIndex, saveParams.ResultSetIndex, rowStartIndex + i * BatchSize, numberOfRows); + + foreach (var row in resultSubset.Rows) + { + csvFile.WriteLine(string.Join(",", row.Skip(columnStartIndex).Take(columnCount).Select(field => + EncodeCsvField((field != null) ? field.ToString() : "NULL")))); + } + } + } + + // Successfully wrote file, send success result + if (SaveCompleted != null) + { + await SaveCompleted(null); + } + } + catch (Exception ex) + { + // Delete file when exception occurs + if (FileUtils.SafeFileExists(saveParams.FilePath)) + { + FileUtils.SafeFileDelete(saveParams.FilePath); + } + + if (SaveFailed != null) + { + await SaveFailed(ex.Message); + } + } + }); + } + + } } \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs index 522c14a7..161e255b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SaveResultsTests.cs @@ -3,10 +3,12 @@ // using System; +using System.Linq; using System.IO; using System.Threading.Tasks; using System.Runtime.InteropServices; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; @@ -46,7 +48,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -89,7 +96,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -128,7 +140,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution string errMessage = null; var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (string) err); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see error message Assert.NotNull(errMessage); @@ -188,7 +205,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; + + // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -223,14 +247,19 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution BatchIndex = 0, FilePath = "testwrite_5.json", RowStartIndex = 0, - RowEndIndex = 0, + RowEndIndex = 1, ColumnStartIndex = 0, - ColumnEndIndex = 0 + ColumnEndIndex = 1 }; SaveResultRequestResult result = null; var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see a file successfully created in filepath and a success message Assert.Null(result.Messages); @@ -265,11 +294,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution BatchIndex = 0, FilePath = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "G:\\test.json" : "/test.json" }; - // SaveResultRequestResult result = null; + + string errMessage = null; var saveRequest = GetSaveResultsContextMock( null, err => errMessage = (string) err); queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch(); - queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object).Wait(); + + // Call save results and wait on the save task + await queryService.HandleSaveResultsAsJsonRequest(saveParams, saveRequest.Object); + ResultSet selectedResultSet = queryService.ActiveQueries[saveParams.OwnerUri].Batches[saveParams.BatchIndex].ResultSets[saveParams.ResultSetIndex]; + Task saveTask = selectedResultSet.GetSaveTask(saveParams.FilePath); + await saveTask; // Expect to see error message Assert.NotNull(errMessage);