From 93a75f1ff49eccb11d84a2846aaf43decc40528d Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Thu, 22 Sep 2016 12:00:32 -0700 Subject: [PATCH] Format Cell Values (#62) * WIP for ability to localize cell values * Changing how DateTimeOffsets are stored, getting unit tests going * Reworking BufferFileStreamWriter to use dictionary approach * Plumbing the DbCellValue type the rest of the way through * Removing unused components to simplify contract * Cleanup and making sure byte[] appears in parity with SSMS * CR comments, small tweaks for optimizing LINQ --- .../QueryExecution/Contracts/DbCellValue.cs | 23 + .../Contracts/ResultSetSubset.cs | 2 +- .../DataStorage/FileStreamReadResult.cs | 28 +- .../DataStorage/IFileStreamReader.cs | 35 +- .../DataStorage/IFileStreamWriter.cs | 4 +- .../ServiceBufferFileStreamReader.cs | 782 +++++------------- .../ServiceBufferFileStreamWriter.cs | 477 ++++------- .../DataStorage/StorageDataReader.cs | 6 + .../QueryExecution/QueryExecutionService.cs | 29 +- .../QueryExecution/ResultSet.cs | 12 +- .../QueryExecution/Common.cs | 4 + ...erviceBufferFileStreamReaderWriterTests.cs | 45 +- .../QueryExecution/SubsetTests.cs | 65 +- 13 files changed, 535 insertions(+), 977 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs new file mode 100644 index 00000000..6eabb4d3 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbCellValue.cs @@ -0,0 +1,23 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Class used for internally passing results from a cell around. + /// + public class DbCellValue + { + /// + /// Display value for the cell, suitable to be passed back to the client + /// + public string DisplayValue { get; set; } + + /// + /// The raw object for the cell, for use internally + /// + internal object RawObject { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs index 8e2b49a9..62308824 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs @@ -19,6 +19,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// /// 2D array of the cell values requested from result set /// - public object[][] Rows { get; set; } + public string[][] Rows { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs index 61ee62e0..0939c9d4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs @@ -3,25 +3,16 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { /// /// Represents a value returned from a read from a file stream. This is used to eliminate ref /// parameters used in the read methods. /// - /// The type of the value that was read - public struct FileStreamReadResult + public struct FileStreamReadResult { - /// - /// Whether or not the value of the field is null - /// - public bool IsNull { get; set; } - - /// - /// The value of the field. If is true, this will be set to default(T) - /// - public T Value { get; set; } - /// /// The total length in bytes of the value, (including the bytes used to store the length /// of the value) @@ -34,17 +25,20 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// public int TotalLength { get; set; } + /// + /// Value of the cell + /// + public DbCellValue Value { get; set; } + /// /// Constructs a new FileStreamReadResult /// - /// The value of the result - /// The number of bytes for the used to store the value's length and value - /// Whether or not the value is null - public FileStreamReadResult(T value, int totalLength, bool isNull) + /// The value of the result, ready for consumption by a client + /// The number of bytes for the used to store the value's length and values + public FileStreamReadResult(DbCellValue value, int totalLength) { Value = value; TotalLength = totalLength; - IsNull = isNull; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs index ea5584f1..cfbe4fa1 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; -using System.Data.SqlTypes; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage @@ -15,21 +14,23 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// public interface IFileStreamReader : IDisposable { - object[] ReadRow(long offset, IEnumerable columns); - FileStreamReadResult ReadInt16(long i64Offset); - FileStreamReadResult ReadInt32(long i64Offset); - FileStreamReadResult ReadInt64(long i64Offset); - FileStreamReadResult ReadByte(long i64Offset); - FileStreamReadResult ReadChar(long i64Offset); - FileStreamReadResult ReadBoolean(long i64Offset); - FileStreamReadResult ReadSingle(long i64Offset); - FileStreamReadResult ReadDouble(long i64Offset); - FileStreamReadResult ReadSqlDecimal(long i64Offset); - FileStreamReadResult ReadDecimal(long i64Offset); - FileStreamReadResult ReadDateTime(long i64Offset); - FileStreamReadResult ReadTimeSpan(long i64Offset); - FileStreamReadResult ReadString(long i64Offset); - FileStreamReadResult ReadBytes(long i64Offset); - FileStreamReadResult ReadDateTimeOffset(long i64Offset); + IList ReadRow(long offset, IEnumerable columns); + FileStreamReadResult ReadInt16(long i64Offset); + FileStreamReadResult ReadInt32(long i64Offset); + FileStreamReadResult ReadInt64(long i64Offset); + FileStreamReadResult ReadByte(long i64Offset); + FileStreamReadResult ReadChar(long i64Offset); + FileStreamReadResult ReadBoolean(long i64Offset); + FileStreamReadResult ReadSingle(long i64Offset); + FileStreamReadResult ReadDouble(long i64Offset); + FileStreamReadResult ReadSqlDecimal(long i64Offset); + FileStreamReadResult ReadDecimal(long i64Offset); + FileStreamReadResult ReadDateTime(long i64Offset); + FileStreamReadResult ReadTimeSpan(long i64Offset); + FileStreamReadResult ReadString(long i64Offset); + FileStreamReadResult ReadBytes(long i64Offset); + FileStreamReadResult ReadDateTimeOffset(long i64Offset); + FileStreamReadResult ReadGuid(long offset); + FileStreamReadResult ReadMoney(long offset); } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs index 968701ed..7cfffee8 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs @@ -29,7 +29,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage int WriteDateTimeOffset(DateTimeOffset dtoVal); int WriteTimeSpan(TimeSpan val); int WriteString(string val); - int WriteBytes(byte[] bytes, int length); + int WriteBytes(byte[] bytes); + int WriteGuid(Guid val); + int WriteMoney(SqlMoney val); void FlushBuffer(); } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs index 9772744f..ab28be89 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs @@ -6,7 +6,6 @@ using System; using System.Collections.Generic; using System.Data.SqlTypes; -using System.Diagnostics; using System.IO; using System.Text; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; @@ -26,6 +25,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage private readonly IFileStreamWrapper fileStream; + private Dictionary> readMethods; + #endregion /// @@ -41,6 +42,40 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // Create internal buffer buffer = new byte[DefaultBufferSize]; + + // Create the methods that will be used to read back + readMethods = new Dictionary> + { + {typeof(string), ReadString}, + {typeof(short), ReadInt16}, + {typeof(int), ReadInt32}, + {typeof(long), ReadInt64}, + {typeof(byte), ReadByte}, + {typeof(char), ReadChar}, + {typeof(bool), ReadBoolean}, + {typeof(double), ReadDouble}, + {typeof(float), ReadSingle}, + {typeof(decimal), ReadDecimal}, + {typeof(DateTime), ReadDateTime}, + {typeof(DateTimeOffset), ReadDateTimeOffset}, + {typeof(TimeSpan), ReadTimeSpan}, + {typeof(byte[]), ReadBytes}, + + {typeof(SqlString), ReadString}, + {typeof(SqlInt16), ReadInt16}, + {typeof(SqlInt32), ReadInt32}, + {typeof(SqlInt64), ReadInt64}, + {typeof(SqlByte), ReadByte}, + {typeof(SqlBoolean), ReadBoolean}, + {typeof(SqlDouble), ReadDouble}, + {typeof(SqlSingle), ReadSingle}, + {typeof(SqlDecimal), ReadSqlDecimal}, + {typeof(SqlDateTime), ReadDateTime}, + {typeof(SqlBytes), ReadBytes}, + {typeof(SqlBinary), ReadBytes}, + {typeof(SqlGuid), ReadGuid}, + {typeof(SqlMoney), ReadMoney}, + }; } #region IFileStreamStorage Implementation @@ -50,12 +85,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file where the row starts /// The columns that were encoded - /// The objects from the row - public object[] ReadRow(long fileOffset, IEnumerable columns) + /// The objects from the row, ready for output to the client + public IList ReadRow(long fileOffset, IEnumerable columns) { // Initialize for the loop long currentFileOffset = fileOffset; - List results = new List(); + List results = new List(); // Iterate over the columns foreach (DbColumnWrapper column in columns) @@ -65,22 +100,23 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage if (column.IsSqlVariant) { // For SQL Variant columns, the type is written first in string format - FileStreamReadResult sqlVariantTypeResult = ReadString(currentFileOffset); + FileStreamReadResult sqlVariantTypeResult = ReadString(currentFileOffset); currentFileOffset += sqlVariantTypeResult.TotalLength; + string sqlVariantType = (string)sqlVariantTypeResult.Value.RawObject; // If the typename is null, then the whole value is null - if (sqlVariantTypeResult.IsNull) + if (sqlVariantTypeResult.Value == null) { - results.Add(null); + results.Add(sqlVariantTypeResult.Value); continue; } // The typename is stored in the string - colType = Type.GetType(sqlVariantTypeResult.Value); + colType = Type.GetType(sqlVariantType); // Workaround .NET bug, see sqlbu# 440643 and vswhidbey# 599834 // TODO: Is this workaround necessary for .NET Core? - if (colType == null && sqlVariantTypeResult.Value == @"System.Data.SqlTypes.SqlSingle") + if (colType == null && sqlVariantType == "System.Data.SqlTypes.SqlSingle") { colType = typeof(SqlSingle); } @@ -90,380 +126,19 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage colType = column.DataType; } - if (colType == typeof(string)) - { - // String - most frequently used data type - FileStreamReadResult result = ReadString(currentFileOffset); - currentFileOffset += result.TotalLength; - results.Add(result.IsNull ? null : result.Value); - } - else if (colType == typeof(SqlString)) - { - // SqlString - FileStreamReadResult result = ReadString(currentFileOffset); - currentFileOffset += result.TotalLength; - results.Add(result.IsNull ? null : (SqlString) result.Value); - } - else if (colType == typeof(short)) - { - // Int16 - FileStreamReadResult result = ReadInt16(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlInt16)) - { - // SqlInt16 - FileStreamReadResult result = ReadInt16(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlInt16)result.Value); - } - } - else if (colType == typeof(int)) - { - // Int32 - FileStreamReadResult result = ReadInt32(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlInt32)) - { - // SqlInt32 - FileStreamReadResult result = ReadInt32(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlInt32)result.Value); - } - } - else if (colType == typeof(long)) - { - // Int64 - FileStreamReadResult result = ReadInt64(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlInt64)) - { - // SqlInt64 - FileStreamReadResult result = ReadInt64(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlInt64)result.Value); - } - } - else if (colType == typeof(byte)) - { - // byte - FileStreamReadResult result = ReadByte(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlByte)) - { - // SqlByte - FileStreamReadResult result = ReadByte(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlByte)result.Value); - } - } - else if (colType == typeof(char)) - { - // Char - FileStreamReadResult result = ReadChar(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(bool)) - { - // Bool - FileStreamReadResult result = ReadBoolean(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlBoolean)) - { - // SqlBoolean - FileStreamReadResult result = ReadBoolean(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlBoolean)result.Value); - } - } - else if (colType == typeof(double)) - { - // double - FileStreamReadResult result = ReadDouble(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlDouble)) - { - // SqlByte - FileStreamReadResult result = ReadDouble(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlDouble)result.Value); - } - } - else if (colType == typeof(float)) - { - // float - FileStreamReadResult result = ReadSingle(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlSingle)) - { - // SqlSingle - FileStreamReadResult result = ReadSingle(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlSingle)result.Value); - } - } - else if (colType == typeof(decimal)) - { - // Decimal - FileStreamReadResult result = ReadDecimal(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlDecimal)) - { - // SqlDecimal - FileStreamReadResult result = ReadSqlDecimal(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(DateTime)) - { - // DateTime - FileStreamReadResult result = ReadDateTime(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlDateTime)) - { - // SqlDateTime - FileStreamReadResult result = ReadDateTime(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add((SqlDateTime)result.Value); - } - } - else if (colType == typeof(DateTimeOffset)) - { - // DateTimeOffset - FileStreamReadResult result = ReadDateTimeOffset(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(TimeSpan)) - { - // TimeSpan - FileStreamReadResult result = ReadTimeSpan(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(byte[])) - { - // Byte Array - FileStreamReadResult result = ReadBytes(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull || (column.IsUdt && result.Value.Length == 0)) - { - results.Add(null); - } - else - { - results.Add(result.Value); - } - } - else if (colType == typeof(SqlBytes)) - { - // SqlBytes - FileStreamReadResult result = ReadBytes(currentFileOffset); - currentFileOffset += result.TotalLength; - results.Add(result.IsNull ? null : new SqlBytes(result.Value)); - } - else if (colType == typeof(SqlBinary)) - { - // SqlBinary - FileStreamReadResult result = ReadBytes(currentFileOffset); - currentFileOffset += result.TotalLength; - results.Add(result.IsNull ? null : new SqlBinary(result.Value)); - } - else if (colType == typeof(SqlGuid)) - { - // SqlGuid - FileStreamReadResult result = ReadBytes(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(new SqlGuid(result.Value)); - } - } - else if (colType == typeof(SqlMoney)) - { - // SqlMoney - FileStreamReadResult result = ReadDecimal(currentFileOffset); - currentFileOffset += result.TotalLength; - if (result.IsNull) - { - results.Add(null); - } - else - { - results.Add(new SqlMoney(result.Value)); - } - } - else + // Use the right read function for the type to read the data from the file + Func readFunc; + if(!readMethods.TryGetValue(colType, out readFunc)) { // Treat everything else as a string - FileStreamReadResult result = ReadString(currentFileOffset); - currentFileOffset += result.TotalLength; - results.Add(result.IsNull ? null : result.Value); - } + readFunc = ReadString; + } + FileStreamReadResult result = readFunc(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.Value); } - return results.ToArray(); + return results; } /// @@ -471,21 +146,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the short from /// A short - public FileStreamReadResult ReadInt16(long fileOffset) + public FileStreamReadResult ReadInt16(long fileOffset) { - - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 2, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - short val = default(short); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = BitConverter.ToInt16(buffer, 0); - } - - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => BitConverter.ToInt16(buffer, 0)); } /// @@ -493,19 +156,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the int from /// An int - public FileStreamReadResult ReadInt32(long fileOffset) + public FileStreamReadResult ReadInt32(long fileOffset) { - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 4, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - int val = default(int); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = BitConverter.ToInt32(buffer, 0); - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => BitConverter.ToInt32(buffer, 0)); } /// @@ -513,19 +166,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the long from /// A long - public FileStreamReadResult ReadInt64(long fileOffset) + public FileStreamReadResult ReadInt64(long fileOffset) { - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 8, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - long val = default(long); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = BitConverter.ToInt64(buffer, 0); - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => BitConverter.ToInt64(buffer, 0)); } /// @@ -533,19 +176,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the byte from /// A byte - public FileStreamReadResult ReadByte(long fileOffset) + public FileStreamReadResult ReadByte(long fileOffset) { - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 1, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - byte val = default(byte); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = buffer[0]; - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => buffer[0]); } /// @@ -553,19 +186,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the char from /// A char - public FileStreamReadResult ReadChar(long fileOffset) + public FileStreamReadResult ReadChar(long fileOffset) { - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 2, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - char val = default(char); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = BitConverter.ToChar(buffer, 0); - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => BitConverter.ToChar(buffer, 0)); } /// @@ -573,19 +196,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the bool from /// A bool - public FileStreamReadResult ReadBoolean(long fileOffset) + public FileStreamReadResult ReadBoolean(long fileOffset) { - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 1, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - bool val = default(bool); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = buffer[0] == 0x01; - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => buffer[0] == 0x1); } /// @@ -593,19 +206,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the single from /// A single - public FileStreamReadResult ReadSingle(long fileOffset) + public FileStreamReadResult ReadSingle(long fileOffset) { - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 4, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - float val = default(float); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = BitConverter.ToSingle(buffer, 0); - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => BitConverter.ToSingle(buffer, 0)); } /// @@ -613,19 +216,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the double from /// A double - public FileStreamReadResult ReadDouble(long fileOffset) + public FileStreamReadResult ReadDouble(long fileOffset) { - LengthResult length = ReadLength(fileOffset); - Debug.Assert(length.ValueLength == 0 || length.ValueLength == 8, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - double val = default(double); - if (!isNull) - { - fileStream.ReadData(buffer, length.ValueLength); - val = BitConverter.ToDouble(buffer, 0); - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + return ReadCellHelper(fileOffset, length => BitConverter.ToDouble(buffer, 0)); } /// @@ -633,23 +226,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the SqlDecimal from /// A SqlDecimal - public FileStreamReadResult ReadSqlDecimal(long offset) + public FileStreamReadResult ReadSqlDecimal(long offset) { - LengthResult length = ReadLength(offset); - Debug.Assert(length.ValueLength == 0 || (length.ValueLength - 3)%4 == 0, - string.Format("Invalid data length: {0}", length.ValueLength)); - - bool isNull = length.ValueLength == 0; - SqlDecimal val = default(SqlDecimal); - if (!isNull) + return ReadCellHelper(offset, length => { - fileStream.ReadData(buffer, length.ValueLength); - - int[] arrInt32 = new int[(length.ValueLength - 3)/4]; - Buffer.BlockCopy(buffer, 3, arrInt32, 0, length.ValueLength - 3); - val = new SqlDecimal(buffer[0], buffer[1], 1 == buffer[2], arrInt32); - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + int[] arrInt32 = new int[(length - 3) / 4]; + Buffer.BlockCopy(buffer, 3, arrInt32, 0, length - 3); + return new SqlDecimal(buffer[0], buffer[1], buffer[2] == 1, arrInt32); + }); } /// @@ -657,22 +241,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the decimal from /// A decimal - public FileStreamReadResult ReadDecimal(long offset) + public FileStreamReadResult ReadDecimal(long offset) { - LengthResult length = ReadLength(offset); - Debug.Assert(length.ValueLength%4 == 0, "Invalid data length"); - - bool isNull = length.ValueLength == 0; - decimal val = default(decimal); - if (!isNull) + return ReadCellHelper(offset, length => { - fileStream.ReadData(buffer, length.ValueLength); - - int[] arrInt32 = new int[length.ValueLength/4]; - Buffer.BlockCopy(buffer, 0, arrInt32, 0, length.ValueLength); - val = new decimal(arrInt32); - } - return new FileStreamReadResult(val, length.TotalLength, isNull); + int[] arrInt32 = new int[length / 4]; + Buffer.BlockCopy(buffer, 0, arrInt32, 0, length); + return new decimal(arrInt32); + }); } /// @@ -680,15 +256,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the DateTime from /// A DateTime - public FileStreamReadResult ReadDateTime(long offset) + public FileStreamReadResult ReadDateTime(long offset) { - FileStreamReadResult ticks = ReadInt64(offset); - DateTime val = default(DateTime); - if (!ticks.IsNull) + return ReadCellHelper(offset, length => { - val = new DateTime(ticks.Value); - } - return new FileStreamReadResult(val, ticks.TotalLength, ticks.IsNull); + long ticks = BitConverter.ToInt64(buffer, 0); + return new DateTime(ticks); + }); } /// @@ -696,27 +270,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the DateTimeOffset from /// A DateTimeOffset - public FileStreamReadResult ReadDateTimeOffset(long offset) + public FileStreamReadResult ReadDateTimeOffset(long offset) { // DateTimeOffset is represented by DateTime.Ticks followed by TimeSpan.Ticks // both as Int64 values - - // read the DateTime ticks - DateTimeOffset val = default(DateTimeOffset); - FileStreamReadResult dateTimeTicks = ReadInt64(offset); - int totalLength = dateTimeTicks.TotalLength; - if (dateTimeTicks.TotalLength > 0 && !dateTimeTicks.IsNull) - { - // read the TimeSpan ticks - FileStreamReadResult timeSpanTicks = ReadInt64(offset + dateTimeTicks.TotalLength); - Debug.Assert(!timeSpanTicks.IsNull, "TimeSpan ticks cannot be null if DateTime ticks are not null!"); - - totalLength += timeSpanTicks.TotalLength; - - // build the DateTimeOffset - val = new DateTimeOffset(new DateTime(dateTimeTicks.Value), new TimeSpan(timeSpanTicks.Value)); - } - return new FileStreamReadResult(val, totalLength, dateTimeTicks.IsNull); + return ReadCellHelper(offset, length => { + long dtTicks = BitConverter.ToInt64(buffer, 0); + long dtOffset = BitConverter.ToInt64(buffer, 8); + return new DateTimeOffset(new DateTime(dtTicks), new TimeSpan(dtOffset)); + }); } /// @@ -724,15 +286,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the TimeSpan from /// A TimeSpan - public FileStreamReadResult ReadTimeSpan(long offset) + public FileStreamReadResult ReadTimeSpan(long offset) { - FileStreamReadResult timeSpanTicks = ReadInt64(offset); - TimeSpan val = default(TimeSpan); - if (!timeSpanTicks.IsNull) + return ReadCellHelper(offset, length => { - val = new TimeSpan(timeSpanTicks.Value); - } - return new FileStreamReadResult(val, timeSpanTicks.TotalLength, timeSpanTicks.IsNull); + long ticks = BitConverter.ToInt64(buffer, 0); + return new TimeSpan(ticks); + }); } /// @@ -740,24 +300,12 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the string from /// A string - public FileStreamReadResult ReadString(long offset) + public FileStreamReadResult ReadString(long offset) { - LengthResult fieldLength = ReadLength(offset); - Debug.Assert(fieldLength.ValueLength%2 == 0, "Invalid data length"); - - if (fieldLength.ValueLength == 0) // there is no data - { - // If the total length is 5 (5 bytes for length, 0 for value), then the string is empty - // Otherwise, the string is null - bool isNull = fieldLength.TotalLength != 5; - return new FileStreamReadResult(isNull ? null : string.Empty, - fieldLength.TotalLength, isNull); - } - - // positive length - AssureBufferLength(fieldLength.ValueLength); - fileStream.ReadData(buffer, fieldLength.ValueLength); - return new FileStreamReadResult(Encoding.Unicode.GetString(buffer, 0, fieldLength.ValueLength), fieldLength.TotalLength, false); + return ReadCellHelper(offset, length => + length > 0 + ? Encoding.Unicode.GetString(buffer, 0, length) + : string.Empty, totalLength => totalLength == 1); } /// @@ -765,23 +313,54 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// /// Offset into the file to read the bytes from /// A byte array - public FileStreamReadResult ReadBytes(long offset) + public FileStreamReadResult ReadBytes(long offset) { - LengthResult fieldLength = ReadLength(offset); - - if (fieldLength.ValueLength == 0) + return ReadCellHelper(offset, length => { - // If the total length is 5 (5 bytes for length, 0 for value), then the byte array is 0x - // Otherwise, the byte array is null - bool isNull = fieldLength.TotalLength != 5; - return new FileStreamReadResult(isNull ? null : new byte[0], - fieldLength.TotalLength, isNull); - } + byte[] output = new byte[length]; + Buffer.BlockCopy(buffer, 0, output, 0, length); + return output; + }, totalLength => totalLength == 1, + bytes => + { + StringBuilder sb = new StringBuilder("0x"); + foreach (byte b in bytes) + { + sb.AppendFormat("{0:X2}", b); + } + return sb.ToString(); + }); + } - // positive length - byte[] val = new byte[fieldLength.ValueLength]; - fileStream.ReadData(val, fieldLength.ValueLength); - return new FileStreamReadResult(val, fieldLength.TotalLength, false); + /// + /// Reads the bytes that make up a GUID at the offset provided + /// + /// Offset into the file to read the bytes from + /// A guid type object + public FileStreamReadResult ReadGuid(long offset) + { + return ReadCellHelper(offset, length => + { + byte[] output = new byte[length]; + Buffer.BlockCopy(buffer, 0, output, 0, length); + return new SqlGuid(output); + }, totalLength => totalLength == 1); + } + + /// + /// Reads a SqlMoney type from the offset provided + /// into a + /// + /// + /// A sql money type object + public FileStreamReadResult ReadMoney(long offset) + { + return ReadCellHelper(offset, length => + { + int[] arrInt32 = new int[length / 4]; + Buffer.BlockCopy(buffer, 0, arrInt32, 0, length); + return new SqlMoney(new decimal(arrInt32)); + }); } /// @@ -813,6 +392,58 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage #endregion + #region Private Helpers + + /// + /// Creates a new buffer that is of the specified length if the buffer is not already + /// at least as long as specified. + /// + /// The minimum buffer size + private void AssureBufferLength(int newBufferLength) + { + if (buffer.Length < newBufferLength) + { + buffer = new byte[newBufferLength]; + } + } + + /// + /// Reads the value of a cell from the file wrapper, checks to see if it null using + /// , and converts it to the proper output type using + /// . + /// + /// Offset into the file to read from + /// Function to use to convert the buffer to the target type + /// + /// If provided, this function will be used to determine if the value is null + /// + /// Optional function to use to convert the object to a string. + /// The expected type of the cell. Used to keep the code honest + /// The object, a display value, and the length of the value + its length + private FileStreamReadResult ReadCellHelper(long offset, Func convertFunc, Func isNullFunc = null, Func toStringFunc = null) + { + LengthResult length = ReadLength(offset); + DbCellValue result = new DbCellValue(); + + if (isNullFunc == null ? length.ValueLength == 0 : isNullFunc(length.TotalLength)) + { + result.RawObject = null; + result.DisplayValue = null; + } + else + { + AssureBufferLength(length.ValueLength); + fileStream.ReadData(buffer, length.ValueLength); + T resultObject = convertFunc(length.ValueLength); + result.RawObject = resultObject; + result.DisplayValue = toStringFunc == null ? result.RawObject.ToString() : toStringFunc(resultObject); + } + + return new FileStreamReadResult(result, length.TotalLength); + } + + #endregion + /// /// Internal struct used for representing the length of a field from the file /// @@ -837,19 +468,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage } } - /// - /// Creates a new buffer that is of the specified length if the buffer is not already - /// at least as long as specified. - /// - /// The minimum buffer size - private void AssureBufferLength(int newBufferLength) - { - if (buffer.Length < newBufferLength) - { - buffer = new byte[newBufferLength]; - } - } - #region IDisposable Implementation private bool disposed; diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs index dfc36487..2e4360d2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs @@ -4,10 +4,10 @@ // using System; +using System.Collections.Generic; using System.Data.SqlTypes; using System.Diagnostics; using System.IO; -using System.Linq; using System.Text; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.Utility; @@ -19,14 +19,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// public class ServiceBufferFileStreamWriter : IFileStreamWriter { - #region Properties + private const int DefaultBufferLength = 8192; - public const int DefaultBufferLength = 8192; - - private int MaxCharsToStore { get; set; } - private int MaxXmlCharsToStore { get; set; } + #region Member Variables + + private readonly IFileStreamWrapper fileStream; + private readonly int maxCharsToStore; + private readonly int maxXmlCharsToStore; - private IFileStreamWrapper FileStream { get; set; } private byte[] byteBuffer; private readonly short[] shortBuffer; private readonly int[] intBuffer; @@ -35,6 +35,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage private readonly double[] doubleBuffer; private readonly float[] floatBuffer; + /// + /// Functions to use for writing various types to a file + /// + private readonly Dictionary> writeMethods; + #endregion /// @@ -47,8 +52,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage public ServiceBufferFileStreamWriter(IFileStreamWrapper fileWrapper, string fileName, int maxCharsToStore, int maxXmlCharsToStore) { // open file for reading/writing - FileStream = fileWrapper; - FileStream.Init(fileName, DefaultBufferLength, FileAccess.ReadWrite); + fileStream = fileWrapper; + fileStream.Init(fileName, DefaultBufferLength, FileAccess.ReadWrite); // create internal buffer byteBuffer = new byte[DefaultBufferLength]; @@ -63,8 +68,42 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage floatBuffer = new float[1]; // Store max chars to store - MaxCharsToStore = maxCharsToStore; - MaxXmlCharsToStore = maxXmlCharsToStore; + this.maxCharsToStore = maxCharsToStore; + this.maxXmlCharsToStore = maxXmlCharsToStore; + + // Define what methods to use to write a type to the file + writeMethods = new Dictionary> + { + {typeof(string), val => WriteString((string) val)}, + {typeof(short), val => WriteInt16((short) val)}, + {typeof(int), val => WriteInt32((int) val)}, + {typeof(long), val => WriteInt64((long) val)}, + {typeof(byte), val => WriteByte((byte) val)}, + {typeof(char), val => WriteChar((char) val)}, + {typeof(bool), val => WriteBoolean((bool) val)}, + {typeof(double), val => WriteDouble((double) val) }, + {typeof(float), val => WriteSingle((float) val) }, + {typeof(decimal), val => WriteDecimal((decimal) val) }, + {typeof(DateTime), val => WriteDateTime((DateTime) val) }, + {typeof(DateTimeOffset), val => WriteDateTimeOffset((DateTimeOffset) val) }, + {typeof(TimeSpan), val => WriteTimeSpan((TimeSpan) val) }, + {typeof(byte[]), val => WriteBytes((byte[]) val)}, + + {typeof(SqlString), val => WriteNullable((SqlString) val, obj => WriteString((string) obj))}, + {typeof(SqlInt16), val => WriteNullable((SqlInt16) val, obj => WriteInt16((short) obj))}, + {typeof(SqlInt32), val => WriteNullable((SqlInt32) val, obj => WriteInt32((int) obj))}, + {typeof(SqlInt64), val => WriteNullable((SqlInt64) val, obj => WriteInt64((long) obj)) }, + {typeof(SqlByte), val => WriteNullable((SqlByte) val, obj => WriteByte((byte) obj)) }, + {typeof(SqlBoolean), val => WriteNullable((SqlBoolean) val, obj => WriteBoolean((bool) obj)) }, + {typeof(SqlDouble), val => WriteNullable((SqlDouble) val, obj => WriteDouble((double) obj)) }, + {typeof(SqlSingle), val => WriteNullable((SqlSingle) val, obj => WriteSingle((float) obj)) }, + {typeof(SqlDecimal), val => WriteNullable((SqlDecimal) val, obj => WriteSqlDecimal((SqlDecimal) obj)) }, + {typeof(SqlDateTime), val => WriteNullable((SqlDateTime) val, obj => WriteDateTime((DateTime) obj)) }, + {typeof(SqlBytes), val => WriteNullable((SqlBytes) val, obj => WriteBytes((byte[]) obj)) }, + {typeof(SqlBinary), val => WriteNullable((SqlBinary) val, obj => WriteBytes((byte[]) obj)) }, + {typeof(SqlGuid), val => WriteNullable((SqlGuid) val, obj => WriteGuid((Guid) obj)) }, + {typeof(SqlMoney), val => WriteNullable((SqlMoney) val, obj => WriteMoney((SqlMoney) obj)) } + }; } #region IFileStreamWriter Implementation @@ -76,22 +115,20 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// Number of bytes used to write the row public int WriteRow(StorageDataReader reader) { - // Determine if we have any long fields - bool hasLongFields = reader.Columns.Any(column => column.IsLong); - + // Read the values in from the db object[] values = new object[reader.Columns.Length]; - int rowBytes = 0; - if (!hasLongFields) + if (!reader.HasLongColumns) { // get all record values in one shot if there are no extra long fields reader.GetValues(values); } // Loop over all the columns and write the values to the temp file + int rowBytes = 0; for (int i = 0; i < reader.Columns.Length; i++) { DbColumnWrapper ci = reader.Columns[i]; - if (hasLongFields) + if (reader.HasLongColumns) { if (reader.IsDBNull(i)) { @@ -111,18 +148,18 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // this is a long field if (ci.IsBytes) { - values[i] = reader.GetBytesWithMaxCapacity(i, MaxCharsToStore); + values[i] = reader.GetBytesWithMaxCapacity(i, maxCharsToStore); } else if (ci.IsChars) { - Debug.Assert(MaxCharsToStore > 0); + Debug.Assert(maxCharsToStore > 0); values[i] = reader.GetCharsWithMaxCapacity(i, - ci.IsXml ? MaxXmlCharsToStore : MaxCharsToStore); + ci.IsXml ? maxXmlCharsToStore : maxCharsToStore); } else if (ci.IsXml) { - Debug.Assert(MaxXmlCharsToStore > 0); - values[i] = reader.GetXmlWithMaxCapacity(i, MaxXmlCharsToStore); + Debug.Assert(maxXmlCharsToStore > 0); + values[i] = reader.GetXmlWithMaxCapacity(i, maxXmlCharsToStore); } else { @@ -133,8 +170,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage } } - Type tVal = values[i].GetType(); // get true type of the object + // Get true type of the object + Type tVal = values[i].GetType(); + // Write the object to a file if (tVal == typeof(DBNull)) { rowBytes += WriteNull(); @@ -148,272 +187,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage rowBytes += WriteString(val); } - if (tVal == typeof(string)) + // Use the appropriate writing method for the type + Func writeMethod; + if (writeMethods.TryGetValue(tVal, out writeMethod)) { - // String - most frequently used data type - string val = (string)values[i]; - rowBytes += WriteString(val); - } - else if (tVal == typeof(SqlString)) - { - // SqlString - SqlString val = (SqlString)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteString(val.Value); - } - } - else if (tVal == typeof(short)) - { - // Int16 - short val = (short)values[i]; - rowBytes += WriteInt16(val); - } - else if (tVal == typeof(SqlInt16)) - { - // SqlInt16 - SqlInt16 val = (SqlInt16)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteInt16(val.Value); - } - } - else if (tVal == typeof(int)) - { - // Int32 - int val = (int)values[i]; - rowBytes += WriteInt32(val); - } - else if (tVal == typeof(SqlInt32)) - { - // SqlInt32 - SqlInt32 val = (SqlInt32)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteInt32(val.Value); - } - } - else if (tVal == typeof(long)) - { - // Int64 - long val = (long)values[i]; - rowBytes += WriteInt64(val); - } - else if (tVal == typeof(SqlInt64)) - { - // SqlInt64 - SqlInt64 val = (SqlInt64)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteInt64(val.Value); - } - } - else if (tVal == typeof(byte)) - { - // Byte - byte val = (byte)values[i]; - rowBytes += WriteByte(val); - } - else if (tVal == typeof(SqlByte)) - { - // SqlByte - SqlByte val = (SqlByte)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteByte(val.Value); - } - } - else if (tVal == typeof(char)) - { - // Char - char val = (char)values[i]; - rowBytes += WriteChar(val); - } - else if (tVal == typeof(bool)) - { - // Boolean - bool val = (bool)values[i]; - rowBytes += WriteBoolean(val); - } - else if (tVal == typeof(SqlBoolean)) - { - // SqlBoolean - SqlBoolean val = (SqlBoolean)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteBoolean(val.Value); - } - } - else if (tVal == typeof(double)) - { - // Double - double val = (double)values[i]; - rowBytes += WriteDouble(val); - } - else if (tVal == typeof(SqlDouble)) - { - // SqlDouble - SqlDouble val = (SqlDouble)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteDouble(val.Value); - } - } - else if (tVal == typeof(SqlSingle)) - { - // SqlSingle - SqlSingle val = (SqlSingle)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteSingle(val.Value); - } - } - else if (tVal == typeof(decimal)) - { - // Decimal - decimal val = (decimal)values[i]; - rowBytes += WriteDecimal(val); - } - else if (tVal == typeof(SqlDecimal)) - { - // SqlDecimal - SqlDecimal val = (SqlDecimal)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteSqlDecimal(val); - } - } - else if (tVal == typeof(DateTime)) - { - // DateTime - DateTime val = (DateTime)values[i]; - rowBytes += WriteDateTime(val); - } - else if (tVal == typeof(DateTimeOffset)) - { - // DateTimeOffset - DateTimeOffset val = (DateTimeOffset)values[i]; - rowBytes += WriteDateTimeOffset(val); - } - else if (tVal == typeof(SqlDateTime)) - { - // SqlDateTime - SqlDateTime val = (SqlDateTime)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteDateTime(val.Value); - } - } - else if (tVal == typeof(TimeSpan)) - { - // TimeSpan - TimeSpan val = (TimeSpan)values[i]; - rowBytes += WriteTimeSpan(val); - } - else if (tVal == typeof(byte[])) - { - // Bytes - byte[] val = (byte[])values[i]; - rowBytes += WriteBytes(val, val.Length); - } - else if (tVal == typeof(SqlBytes)) - { - // SqlBytes - SqlBytes val = (SqlBytes)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteBytes(val.Value, val.Value.Length); - } - } - else if (tVal == typeof(SqlBinary)) - { - // SqlBinary - SqlBinary val = (SqlBinary)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteBytes(val.Value, val.Value.Length); - } - } - else if (tVal == typeof(SqlGuid)) - { - // SqlGuid - SqlGuid val = (SqlGuid)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - byte[] bytesVal = val.ToByteArray(); - rowBytes += WriteBytes(bytesVal, bytesVal.Length); - } - } - else if (tVal == typeof(SqlMoney)) - { - // SqlMoney - SqlMoney val = (SqlMoney)values[i]; - if (val.IsNull) - { - rowBytes += WriteNull(); - } - else - { - rowBytes += WriteDecimal(val.Value); - } + rowBytes += writeMethod(values[i]); } else { - // treat everything else as string - string val = values[i].ToString(); - rowBytes += WriteString(val); + rowBytes += WriteString(values[i].ToString()); } } } @@ -430,7 +212,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage public int WriteNull() { byteBuffer[0] = 0x00; - return FileStream.WriteData(byteBuffer, 1); + return fileStream.WriteData(byteBuffer, 1); } /// @@ -442,7 +224,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x02; // length shortBuffer[0] = val; Buffer.BlockCopy(shortBuffer, 0, byteBuffer, 1, 2); - return FileStream.WriteData(byteBuffer, 3); + return fileStream.WriteData(byteBuffer, 3); } /// @@ -454,7 +236,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x04; // length intBuffer[0] = val; Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); - return FileStream.WriteData(byteBuffer, 5); + return fileStream.WriteData(byteBuffer, 5); } /// @@ -466,7 +248,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x08; // length longBuffer[0] = val; Buffer.BlockCopy(longBuffer, 0, byteBuffer, 1, 8); - return FileStream.WriteData(byteBuffer, 9); + return fileStream.WriteData(byteBuffer, 9); } /// @@ -478,7 +260,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x02; // length charBuffer[0] = val; Buffer.BlockCopy(charBuffer, 0, byteBuffer, 1, 2); - return FileStream.WriteData(byteBuffer, 3); + return fileStream.WriteData(byteBuffer, 3); } /// @@ -489,7 +271,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { byteBuffer[0] = 0x01; // length byteBuffer[1] = (byte) (val ? 0x01 : 0x00); - return FileStream.WriteData(byteBuffer, 2); + return fileStream.WriteData(byteBuffer, 2); } /// @@ -500,7 +282,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage { byteBuffer[0] = 0x01; // length byteBuffer[1] = val; - return FileStream.WriteData(byteBuffer, 2); + return fileStream.WriteData(byteBuffer, 2); } /// @@ -512,7 +294,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x04; // length floatBuffer[0] = val; Buffer.BlockCopy(floatBuffer, 0, byteBuffer, 1, 4); - return FileStream.WriteData(byteBuffer, 5); + return fileStream.WriteData(byteBuffer, 5); } /// @@ -524,7 +306,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[0] = 0x08; // length doubleBuffer[0] = val; Buffer.BlockCopy(doubleBuffer, 0, byteBuffer, 1, 8); - return FileStream.WriteData(byteBuffer, 9); + return fileStream.WriteData(byteBuffer, 9); } /// @@ -548,7 +330,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // data value Buffer.BlockCopy(arrInt32, 0, byteBuffer, 3, iLen - 3); - iTotalLen += FileStream.WriteData(byteBuffer, iLen); + iTotalLen += fileStream.WriteData(byteBuffer, iLen); return iTotalLen; // len+data } @@ -564,7 +346,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage int iTotalLen = WriteLength(iLen); // length Buffer.BlockCopy(arrInt32, 0, byteBuffer, 0, iLen); - iTotalLen += FileStream.WriteData(byteBuffer, iLen); + iTotalLen += fileStream.WriteData(byteBuffer, iLen); return iTotalLen; // len+data } @@ -584,9 +366,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// Number of bytes used to store the DateTimeOffset public int WriteDateTimeOffset(DateTimeOffset dtoVal) { - // DateTimeOffset gets written as a DateTime + TimeOffset - // both represented as 'Ticks' written as Int64's - return WriteInt64(dtoVal.Ticks) + WriteInt64(dtoVal.Offset.Ticks); + // Write the length, which is the 2*sizeof(long) + byteBuffer[0] = 0x10; // length (16) + + // Write the two longs, the datetime and the offset + long[] longBufferOffset = new long[2]; + longBufferOffset[0] = dtoVal.Ticks; + longBufferOffset[1] = dtoVal.Offset.Ticks; + Buffer.BlockCopy(longBufferOffset, 0, byteBuffer, 1, 16); + return fileStream.WriteData(byteBuffer, 17); } /// @@ -618,7 +406,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage byteBuffer[3] = 0x00; byteBuffer[4] = 0x00; - iTotalLen = FileStream.WriteData(byteBuffer, 5); + iTotalLen = fileStream.WriteData(byteBuffer, 5); } else { @@ -627,7 +415,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // convert char array into byte array and write it out iTotalLen = WriteLength(bytes.Length); - iTotalLen += FileStream.WriteData(bytes, bytes.Length); + iTotalLen += fileStream.WriteData(bytes, bytes.Length); } return iTotalLen; // len+data } @@ -636,32 +424,76 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// Writes a byte[] to the file /// /// Number of bytes used to store the byte[] - public int WriteBytes(byte[] bytesVal, int iLen) + public int WriteBytes(byte[] bytesVal) { Validate.IsNotNull(nameof(bytesVal), bytesVal); int iTotalLen; - if (0 == iLen) // special case of 0 length byte array "0x" + if (bytesVal.Length == 0) // special case of 0 length byte array "0x" { - iLen = 5; - - AssureBufferLength(iLen); + AssureBufferLength(5); byteBuffer[0] = 0xFF; byteBuffer[1] = 0x00; byteBuffer[2] = 0x00; byteBuffer[3] = 0x00; byteBuffer[4] = 0x00; - iTotalLen = FileStream.WriteData(byteBuffer, iLen); + iTotalLen = fileStream.WriteData(byteBuffer, 5); } else { - iTotalLen = WriteLength(iLen); - iTotalLen += FileStream.WriteData(bytesVal, iLen); + iTotalLen = WriteLength(bytesVal.Length); + iTotalLen += fileStream.WriteData(bytesVal, bytesVal.Length); } return iTotalLen; // len+data } + /// + /// Stores a GUID value to the file by treating it as a byte array + /// + /// The GUID to write to the file + /// Number of bytes written to the file + public int WriteGuid(Guid val) + { + byte[] guidBytes = val.ToByteArray(); + return WriteBytes(guidBytes); + } + + /// + /// Stores a SqlMoney value to the file by treating it as a decimal + /// + /// The SqlMoney value to write to the file + /// Number of bytes written to the file + public int WriteMoney(SqlMoney val) + { + return WriteDecimal(val.Value); + } + + /// + /// Flushes the internal buffer to the file stream + /// + public void FlushBuffer() + { + fileStream.Flush(); + } + + #endregion + + #region Private Helpers + + /// + /// Creates a new buffer that is of the specified length if the buffer is not already + /// at least as long as specified. + /// + /// The minimum buffer size + private void AssureBufferLength(int newBufferLength) + { + if (newBufferLength > byteBuffer.Length) + { + byteBuffer = new byte[byteBuffer.Length]; + } + } + /// /// Writes the length of the field using the appropriate number of bytes (ie, 1 if the /// length is <255, 5 if the length is >=255) @@ -675,7 +507,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage int iTmp = iLen & 0x000000FF; byteBuffer[0] = Convert.ToByte(iTmp); - return FileStream.WriteData(byteBuffer, 1); + return fileStream.WriteData(byteBuffer, 1); } // The length won't fit in 1 byte, so we need to use 1 byte to signify that the length // is a full 4 bytes. @@ -684,27 +516,24 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // convert int32 into array of bytes intBuffer[0] = iLen; Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); - return FileStream.WriteData(byteBuffer, 5); + return fileStream.WriteData(byteBuffer, 5); } /// - /// Flushes the internal buffer to the file stream + /// Writes a Nullable type (generally a Sql* type) to the file. The function provided by + /// is used to write to the file if + /// is not null. is used if is null. /// - public void FlushBuffer() + /// The value to write to the file + /// The function to use if val is not null + /// Number of bytes used to write value to the file + private int WriteNullable(INullable val, Func valueWriteFunc) { - FileStream.Flush(); + return val.IsNull ? WriteNull() : valueWriteFunc(val); } #endregion - private void AssureBufferLength(int newBufferLength) - { - if (newBufferLength > byteBuffer.Length) - { - byteBuffer = new byte[byteBuffer.Length]; - } - } - #region IDisposable Implementation private bool disposed; @@ -724,8 +553,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage if (disposing) { - FileStream.Flush(); - FileStream.Dispose(); + fileStream.Flush(); + fileStream.Dispose(); } disposed = true; diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs index 1e45d437..cc5d1443 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs @@ -57,6 +57,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage // Read the columns into a set of wrappers Columns = DbDataReader.GetColumnSchema().Select(column => new DbColumnWrapper(column)).ToArray(); + HasLongColumns = Columns.Any(column => column.IsLong); } #region Properties @@ -71,6 +72,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage /// public DbDataReader DbDataReader { get; private set; } + /// + /// Whether or not any of the columns of this reader are 'long', such as nvarchar(max) + /// + public bool HasLongColumns { get; private set; } + #endregion #region DbDataReader Methods diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index 5d4e1ad5..562a3a61 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -282,17 +282,19 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { // get the requested resultSet from query Batch selectedBatch = result.Batches[saveParams.BatchIndex]; - ResultSet selectedResultSet = (selectedBatch.ResultSets.ToList())[saveParams.ResultSetIndex]; + ResultSet selectedResultSet = selectedBatch.ResultSets.ToList()[saveParams.ResultSetIndex]; if (saveParams.IncludeHeaders) { // write column names to csv - await csvFile.WriteLineAsync( string.Join( ",", selectedResultSet.Columns.Select( column => SaveResults.EncodeCsvField(column.ColumnName) ?? string.Empty))); + await csvFile.WriteLineAsync(string.Join(",", + selectedResultSet.Columns.Select(column => SaveResults.EncodeCsvField(column.ColumnName) ?? string.Empty))); } // write rows to csv foreach (var row in selectedResultSet.Rows) { - await csvFile.WriteLineAsync( string.Join( ",", row.Select( field => SaveResults.EncodeCsvField((field != null) ? field.ToString(): string.Empty)))); + await csvFile.WriteLineAsync(string.Join(",", + row.Select(field => SaveResults.EncodeCsvField(field ?? string.Empty)))); } } @@ -336,23 +338,26 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution // get the requested resultSet from query Batch selectedBatch = result.Batches[saveParams.BatchIndex]; - ResultSet selectedResultSet = (selectedBatch.ResultSets.ToList())[saveParams.ResultSetIndex]; + ResultSet selectedResultSet = selectedBatch.ResultSets.ToList()[saveParams.ResultSetIndex]; // write each row to JSON foreach (var row in selectedResultSet.Rows) { jsonWriter.WriteStartObject(); - foreach (var field in row.Select((value,i) => new {value, i})) + for (int i = 0; i < row.Length; i++) { - jsonWriter.WritePropertyName(selectedResultSet.Columns[field.i].ColumnName); - if (field.value != null) - { - jsonWriter.WriteValue(field.value); - } - else + DbColumnWrapper col = selectedResultSet.Columns[i]; + string val = row[i]; + + jsonWriter.WritePropertyName(col.ColumnName); + if (val == null) { jsonWriter.WriteNull(); - } + } + else + { + jsonWriter.WriteValue(val); + } } jsonWriter.WriteEndObject(); } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index b4dd411d..58933532 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -112,9 +112,13 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// /// The rows of this result set /// - public IEnumerable Rows + public IEnumerable Rows { - get { return FileOffsets.Select(offset => fileStreamReader.ReadRow(offset, Columns)); } + get + { + return FileOffsets.Select( + offset => fileStreamReader.ReadRow(offset, Columns).Select(cell => cell.DisplayValue).ToArray()); + } } #endregion @@ -151,7 +155,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); // Iterate over the rows we need and process them into output - object[][] rows = rowOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)).ToArray(); + string[][] rows = rowOffsets.Select(rowOffset => + fileStreamReader.ReadRow(rowOffset, Columns).Select(cell => cell.DisplayValue).ToArray()) + .ToArray(); // Retrieve the subset of the results as per the request return new ResultSetSubset diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index c7a00c09..56273be4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -184,6 +184,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution connectionMock.Protected() .Setup("CreateDbCommand") .Returns(CreateTestCommand(data, throwOnRead)); + connectionMock.Setup(dbc => dbc.Open()) + .Callback(() => connectionMock.SetupGet(dbc => dbc.State).Returns(ConnectionState.Open)); + connectionMock.Setup(dbc => dbc.Close()) + .Callback(() => connectionMock.SetupGet(dbc => dbc.State).Returns(ConnectionState.Closed)); return connectionMock.Object; } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs index b10a7f92..87076204 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs @@ -14,7 +14,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage { public class ReaderWriterPairTest { - private static void VerifyReadWrite(int valueLength, T value, Func writeFunc, Func> readFunc) + private static void VerifyReadWrite(int valueLength, T value, Func writeFunc, Func readFunc) { // Setup: Create a mock file stream wrapper Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); @@ -29,16 +29,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage } // ... And read the type T back - FileStreamReadResult outValue; + FileStreamReadResult outValue; using (ServiceBufferFileStreamReader reader = new ServiceBufferFileStreamReader(mockWrapper, "abc")) { outValue = readFunc(reader); } // Then: - Assert.Equal(value, outValue.Value); + Assert.Equal(value, outValue.Value.RawObject); Assert.Equal(valueLength, outValue.TotalLength); - Assert.False(outValue.IsNull); + Assert.NotNull(outValue.Value); } finally { @@ -200,7 +200,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage }; foreach (DateTimeOffset value in testValues) { - VerifyReadWrite((sizeof(long) + 1)*2, value, (writer, val) => writer.WriteDateTimeOffset(val), reader => reader.ReadDateTimeOffset(0)); + VerifyReadWrite(sizeof(long)*2 + 1, value, (writer, val) => writer.WriteDateTimeOffset(val), reader => reader.ReadDateTimeOffset(0)); } } @@ -267,7 +267,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage { // Then: // ... I should get an argument null exception - Assert.Throws(() => writer.WriteBytes(null, 0)); + Assert.Throws(() => writer.WriteBytes(null)); } } @@ -289,7 +289,38 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage byte[] value = sb.ToArray(); int lengthLength = length == 0 || length > 255 ? 5 : 1; int valueLength = sizeof(byte)*length + lengthLength; - VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteBytes(value, length), reader => reader.ReadBytes(0)); + VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteBytes(value), reader => reader.ReadBytes(0)); + } + + [Fact] + public void GuidTest() + { + // Setup: + // ... Create some test values + // NOTE: We are doing these here instead of InlineData because Guid type can't be written as constant expressions + Guid[] guids = + { + Guid.Empty, Guid.NewGuid(), Guid.NewGuid() + }; + foreach (Guid guid in guids) + { + VerifyReadWrite(guid.ToByteArray().Length + 1, new SqlGuid(guid), (writer, val) => writer.WriteGuid(guid), reader => reader.ReadGuid(0)); + } + } + + [Fact] + public void MoneyTest() + { + // Setup: Create some test values + // NOTE: We are doing these here instead of InlineData because SqlMoney can't be written as a constant expression + SqlMoney[] monies = + { + SqlMoney.Zero, SqlMoney.MinValue, SqlMoney.MaxValue, new SqlMoney(1.02) + }; + foreach (SqlMoney money in monies) + { + VerifyReadWrite(sizeof(decimal) + 1, money, (writer, val) => writer.WriteMoney(money), reader => reader.ReadMoney(0)); + } } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs index 1a50dd55..8212eba3 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs @@ -4,6 +4,7 @@ // using System; +using System.Linq; using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution; @@ -17,6 +18,48 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { public class SubsetTests { + #region ResultSet Class Tests + + [Theory] + [InlineData(0,2)] + [InlineData(0,20)] + [InlineData(1,2)] + public void ResultSetValidTest(int startRow, int rowCount) + { + // Setup: + // ... I have a batch that has been executed + Batch b = Common.GetBasicExecutedBatch(); + + // If: + // ... I have a result set and I ask for a subset with valid arguments + ResultSet rs = b.ResultSets.First(); + ResultSetSubset subset = rs.GetSubset(startRow, rowCount).Result; + + // Then: + // ... I should get the requested number of rows back + Assert.Equal(Math.Min(rowCount, Common.StandardTestData.Length), subset.RowCount); + Assert.Equal(Math.Min(rowCount, Common.StandardTestData.Length), subset.Rows.Length); + } + + [Theory] + [InlineData(-1, 2)] // Invalid start index, too low + [InlineData(10, 2)] // Invalid start index, too high + [InlineData(0, -1)] // Invalid row count, too low + [InlineData(0, 0)] // Invalid row count, zero + public void ResultSetInvalidParmsTest(int rowStartIndex, int rowCount) + { + // If: + // I have an executed batch with a resultset in it and request invalid result set from it + Batch b = Common.GetBasicExecutedBatch(); + ResultSet rs = b.ResultSets.First(); + + // Then: + // ... It should throw an exception + Assert.ThrowsAsync(() => rs.GetSubset(rowStartIndex, rowCount)).Wait(); + } + + #endregion + #region Batch Class Tests [Theory] @@ -37,13 +80,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution } [Theory] - [InlineData(-1, 0, 2)] // Invalid result set, too low - [InlineData(2, 0, 2)] // Invalid result set, too high - [InlineData(0, -1, 2)] // Invalid start index, too low - [InlineData(0, 10, 2)] // Invalid start index, too high - [InlineData(0, 0, -1)] // Invalid row count, too low - [InlineData(0, 0, 0)] // Invalid row count, zero - public void BatchSubsetInvalidParamsTest(int resultSetIndex, int rowStartInex, int rowCount) + [InlineData(-1)] // Invalid result set, too low + [InlineData(2)] // Invalid result set, too high + public void BatchSubsetInvalidParamsTest(int resultSetIndex) { // If I have an executed batch Batch b = Common.GetBasicExecutedBatch(); @@ -51,7 +90,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... And I ask for a subset with an invalid result set index // Then: // ... It should throw an exception - Assert.ThrowsAsync(() => b.GetSubset(resultSetIndex, rowStartInex, rowCount)).Wait(); + Assert.ThrowsAsync(() => b.GetSubset(resultSetIndex, 0, 2)).Wait(); } #endregion @@ -95,7 +134,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... I have a query that has results (doesn't matter what) var queryService =Common.GetPrimedExecutionService( Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true); - var executeParams = new QueryExecuteParams {QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri}; + var executeParams = new QueryExecuteParams {QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri}; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); await queryService.HandleExecuteRequest(executeParams, executeRequest.Object); @@ -141,7 +180,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... I have a query that hasn't finished executing (doesn't matter what) var queryService = Common.GetPrimedExecutionService( Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true); - var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri }; + var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; @@ -168,7 +207,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution // ... I have a query that doesn't have any result sets var queryService = Common.GetPrimedExecutionService( Common.CreateMockFactory(null, false), true); - var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri }; + var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri }; var executeRequest = RequestContextMocks.SetupRequestContextMock(null, QueryExecuteCompleteEvent.Type, null, null); queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); @@ -191,7 +230,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution #region Mocking - private Mock> GetQuerySubsetResultContextMock( + private static Mock> GetQuerySubsetResultContextMock( Action resultCallback, Action errorCallback) { @@ -218,7 +257,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution return requestContext; } - private void VerifyQuerySubsetCallCount(Mock> mock, Times sendResultCalls, + private static void VerifyQuerySubsetCallCount(Mock> mock, Times sendResultCalls, Times sendErrorCalls) { mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls);