diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs index 38528b60..69250afe 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs @@ -1,7 +1,6 @@ -// +// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -// using System; using System.Collections.Generic; @@ -11,18 +10,60 @@ using System.Data.SqlClient; using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.SqlTools.EditorServices.Utility; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { /// /// This class represents a batch within a query /// - public class Batch + public class Batch : IDisposable { private const string RowsAffectedFormat = "({0} row(s) affected)"; + #region Member Variables + + /// + /// For IDisposable implementation, whether or not this has been disposed + /// + private bool disposed; + + /// + /// Factory for creating readers/writrs for the output of the batch + /// + private readonly IFileStreamFactory outputFileFactory; + + /// + /// Internal representation of the messages so we can modify internally + /// + private readonly List resultMessages; + + /// + /// Internal representation of the result sets so we can modify internally + /// + private readonly List resultSets; + + #endregion + + internal Batch(string batchText, int startLine, IFileStreamFactory outputFileFactory) + { + // Sanity check for input + Validate.IsNotNullOrEmptyString(nameof(batchText), batchText); + Validate.IsNotNull(nameof(outputFileFactory), outputFileFactory); + + // Initialize the internal state + BatchText = batchText; + StartLine = startLine - 1; // -1 to make sure that the line number of the batch is 0-indexed, since SqlParser gives 1-indexed line numbers + HasExecuted = false; + resultSets = new List(); + resultMessages = new List(); + this.outputFileFactory = outputFileFactory; + } + #region Properties + /// /// The text of batch that will be executed /// @@ -38,11 +79,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public bool HasExecuted { get; set; } - /// - /// Internal representation of the messages so we can modify internally - /// - private List resultMessages; - /// /// Messages that have come back from the server /// @@ -51,11 +87,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution get { return resultMessages; } } - /// - /// Internal representation of the result sets so we can modify internally - /// - private List resultSets; - /// /// The result sets of the batch execution /// @@ -75,7 +106,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { ColumnInfo = set.Columns, Id = index, - RowCount = set.Rows.Count + RowCount = set.RowCount }).ToArray(); } } @@ -87,21 +118,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #endregion - public Batch(string batchText, int startLine) - { - // Sanity check for input - if (string.IsNullOrEmpty(batchText)) - { - throw new ArgumentNullException(nameof(batchText), "Query text cannot be null"); - } - - // Initialize the internal state - BatchText = batchText; - StartLine = startLine - 1; // -1 to make sure that the line number of the batch is 0-indexed, since SqlParser gives 1-indexed line numbers - 
HasExecuted = false; - resultSets = new List(); - resultMessages = new List(); - } + #region Public Methods /// /// Executes this batch and captures any server messages that are returned. @@ -148,23 +165,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } // Read until we hit the end of the result set - ResultSet resultSet = new ResultSet(); - while (await reader.ReadAsync(cancellationToken)) - { - resultSet.AddRow(reader); - } - - // Read off the column schema information - if (reader.CanGetColumnSchema()) - { - resultSet.Columns = reader.GetColumnSchema().ToArray(); - } + ResultSet resultSet = new ResultSet(reader, outputFileFactory); + await resultSet.ReadResultToEnd(cancellationToken); // Add the result set to the results of the query resultSets.Add(resultSet); // Add a message for the number of rows the query returned - resultMessages.Add(string.Format(RowsAffectedFormat, resultSet.Rows.Count)); + resultMessages.Add(string.Format(RowsAffectedFormat, resultSet.RowCount)); } while (await reader.NextResultAsync(cancellationToken)); } } @@ -200,7 +208,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// The starting row of the results /// How many rows to retrieve /// A subset of results - public ResultSetSubset GetSubset(int resultSetIndex, int startRow, int rowCount) + public Task GetSubset(int resultSetIndex, int startRow, int rowCount) { // Sanity check to make sure we have valid numbers if (resultSetIndex < 0 || resultSetIndex >= resultSets.Count) @@ -213,6 +221,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return resultSets[resultSetIndex].GetSubset(startRow, rowCount); } + #endregion + #region Private Helpers /// @@ -259,5 +269,33 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } #endregion + + #region IDisposable Implementation + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + foreach (ResultSet r in ResultSets) + { + r.Dispose(); + } + } + + disposed = true; + } + + #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs new file mode 100644 index 00000000..e80eada5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs @@ -0,0 +1,226 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
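Taken together, the Batch changes above route all result storage through an injected IFileStreamFactory and make the batch disposable so its backing buffer files get cleaned up. A minimal sketch of the intended call pattern follows; Execute's parameter list is not shown in this hunk, so the connection/cancellation arguments are assumptions, and the internal constructor means a call site like this would live inside the service layer itself.

    // Hypothetical call site; Execute's signature is assumed, not taken from this diff.
    Batch batch = new Batch("SELECT name FROM sys.databases", 1, new ServiceBufferFileStreamFactory());
    try
    {
        await batch.Execute(sqlConnection, cancellationToken);   // assumed parameters
        var subset = await batch.GetSubset(0, 0, 50);            // first 50 rows of result set 0
    }
    finally
    {
        batch.Dispose();   // disposes each ResultSet, which releases its buffer file
    }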
+ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Data.SqlTypes; +using System.Diagnostics; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Wrapper around a DbColumn, which provides extra functionality, but can be used as a + /// regular DbColumn + /// + public class DbColumnWrapper : DbColumn + { + /// + /// All types supported by the server, stored as a hash set to provide O(1) lookup + /// + private static readonly HashSet AllServerDataTypes = new HashSet + { + "bigint", + "binary", + "bit", + "char", + "datetime", + "decimal", + "float", + "image", + "int", + "money", + "nchar", + "ntext", + "nvarchar", + "real", + "uniqueidentifier", + "smalldatetime", + "smallint", + "smallmoney", + "text", + "timestamp", + "tinyint", + "varbinary", + "varchar", + "sql_variant", + "xml", + "date", + "time", + "datetimeoffset", + "datetime2" + }; + + private readonly DbColumn internalColumn; + + /// + /// Constructor for a DbColumnWrapper + /// + /// Most of this logic is taken from SSMS ColumnInfo class + /// The column we're wrapping around + public DbColumnWrapper(DbColumn column) + { + internalColumn = column; + + switch (column.DataTypeName) + { + case "varchar": + case "nvarchar": + IsChars = true; + + Debug.Assert(column.ColumnSize.HasValue); + if (column.ColumnSize.Value == int.MaxValue) + { + //For Yukon, special case nvarchar(max) with column name == "Microsoft SQL Server 2005 XML Showplan" - + //assume it is an XML showplan. + //Please note this field must be in sync with a similar field defined in QESQLBatch.cs. + //This is not the best fix that we could do but we are trying to minimize code impact + //at this point. Post Yukon we should review this code again and avoid + //hard-coding special column name in multiple places. + const string YukonXmlShowPlanColumn = "Microsoft SQL Server 2005 XML Showplan"; + if (column.ColumnName == YukonXmlShowPlanColumn) + { + // Indicate that this is xml to apply the right size limit + // Note we leave chars type as well to use the right retrieval mechanism. + IsXml = true; + } + IsLong = true; + } + break; + case "text": + case "ntext": + IsChars = true; + IsLong = true; + break; + case "xml": + IsXml = true; + IsLong = true; + break; + case "binary": + case "image": + IsBytes = true; + IsLong = true; + break; + case "varbinary": + case "rowversion": + IsBytes = true; + + Debug.Assert(column.ColumnSize.HasValue); + if (column.ColumnSize.Value == int.MaxValue) + { + IsLong = true; + } + break; + case "sql_variant": + IsSqlVariant = true; + break; + default: + if (!AllServerDataTypes.Contains(column.DataTypeName)) + { + // treat all UDT's as long/bytes data types to prevent the CLR from attempting + // to load the UDT assembly into our process to call ToString() on the object. 
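For context on how these flags are consumed: the wrapper is built from the column schema the data reader exposes, and IsLong/IsBytes/IsXml then decide whether a value is fetched wholesale or streamed with a capacity cap. A hedged sketch (reader is an assumed open DbDataReader; requires System.Data.Common and System.Linq):

    DbColumnWrapper[] columns = reader.GetColumnSchema()
        .Select(c => new DbColumnWrapper(c))
        .ToArray();
    foreach (DbColumnWrapper column in columns)
    {
        if (column.IsLong)
        {
            // varchar(max)/xml/UDT columns: read via a max-capacity helper instead of GetValue
        }
    }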
+ + IsUdt = true; + IsBytes = true; + IsLong = true; + } + break; + } + + + if (IsUdt) + { + // udtassemblyqualifiedname property is used to find if the datatype is of hierarchyid assembly type + // Internally hiearchyid is sqlbinary so providerspecific type and type is changed to sqlbinarytype + object assemblyQualifiedName = internalColumn.UdtAssemblyQualifiedName; + const string hierarchyId = "MICROSOFT.SQLSERVER.TYPES.SQLHIERARCHYID"; + + if (assemblyQualifiedName != null && + string.Equals(assemblyQualifiedName.ToString(), hierarchyId, StringComparison.OrdinalIgnoreCase)) + { + DataType = typeof(SqlBinary); + } + else + { + DataType = typeof(byte[]); + } + } + else + { + DataType = DataType; + } + } + + #region Properties + + /// + /// Whether or not the column is bytes + /// + public bool IsBytes { get; private set; } + + /// + /// Whether or not the column is a character type + /// + public bool IsChars { get; private set; } + + /// + /// Whether or not the column is a long type (eg, varchar(MAX)) + /// + public new bool IsLong { get; private set; } + + /// + /// Whether or not the column is a SqlVariant type + /// + public bool IsSqlVariant { get; private set; } + + /// + /// Whether or not the column is a user-defined type + /// + public bool IsUdt { get; private set; } + + /// + /// Whether or not the column is XML + /// + public bool IsXml { get; private set; } + + #endregion + + #region DbColumn Fields + + /// + /// Override for column name, if null or empty, we default to a "no column name" value + /// + public new string ColumnName + { + get + { + // TODO: Localize + return string.IsNullOrEmpty(internalColumn.ColumnName) ? "(No column name)" : internalColumn.ColumnName; + } + } + + public new bool? AllowDBNull { get { return internalColumn.AllowDBNull; } } + public new string BaseCatalogName { get { return internalColumn.BaseCatalogName; } } + public new string BaseColumnName { get { return internalColumn.BaseColumnName; } } + public new string BaseServerName { get { return internalColumn.BaseServerName; } } + public new string BaseTableName { get { return internalColumn.BaseTableName; } } + public new int? ColumnOrdinal { get { return internalColumn.ColumnOrdinal; } } + public new int? ColumnSize { get { return internalColumn.ColumnSize; } } + public new bool? IsAliased { get { return internalColumn.IsAliased; } } + public new bool? IsAutoIncrement { get { return internalColumn.IsAutoIncrement; } } + public new bool? IsExpression { get { return internalColumn.IsExpression; } } + public new bool? IsHidden { get { return internalColumn.IsHidden; } } + public new bool? IsIdentity { get { return internalColumn.IsIdentity; } } + public new bool? IsKey { get { return internalColumn.IsKey; } } + public new bool? IsReadOnly { get { return internalColumn.IsReadOnly; } } + public new bool? IsUnique { get { return internalColumn.IsUnique; } } + public new int? NumericPrecision { get { return internalColumn.NumericPrecision; } } + public new int? 
NumericScale { get { return internalColumn.NumericScale; } } + public new string UdtAssemblyQualifiedName { get { return internalColumn.UdtAssemblyQualifiedName; } } + public new Type DataType { get; private set; } + public new string DataTypeName { get { return internalColumn.DataTypeName; } } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs index b0a6d75c..c8705d8b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs @@ -3,8 +3,6 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System.Data.Common; - namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts { /// @@ -20,11 +18,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts /// /// The number of rows that was returned with the resultset /// - public int RowCount { get; set; } + public long RowCount { get; set; } /// /// Details about the columns that are provided as solutions /// - public DbColumn[] ColumnInfo { get; set; } + public DbColumnWrapper[] ColumnInfo { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs new file mode 100644 index 00000000..61ee62e0 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs @@ -0,0 +1,50 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Represents a value returned from a read from a file stream. This is used to eliminate ref + /// parameters used in the read methods. + /// + /// The type of the value that was read + public struct FileStreamReadResult + { + /// + /// Whether or not the value of the field is null + /// + public bool IsNull { get; set; } + + /// + /// The value of the field. If is true, this will be set to default(T) + /// + public T Value { get; set; } + + /// + /// The total length in bytes of the value, (including the bytes used to store the length + /// of the value) + /// + /// + /// Cell values are stored such that the length of the value is stored first, then the + /// value itself is stored. Eg, a string may be stored as 0x03 0x6C 0x6F 0x6C. Under this + /// system, the value would be "lol", the length would be 3, and the total length would be + /// 4 bytes. 
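That total-length bookkeeping is what lets a caller walk a row of variable-width values with nothing but a running offset: every read advances the offset by TotalLength (length prefix plus value). A condensed sketch of the pattern used later in ReadRow (reader and rowOffset are assumed to exist):

    long offset = rowOffset;
    var name = reader.ReadString(offset);   // FileStreamReadResult for a string cell
    offset += name.TotalLength;             // skip length prefix + value bytes
    var count = reader.ReadInt32(offset);
    offset += count.TotalLength;
    object[] row =
    {
        name.IsNull ? null : (object)name.Value,
        count.IsNull ? null : (object)count.Value
    };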
+ /// + public int TotalLength { get; set; } + + /// + /// Constructs a new FileStreamReadResult + /// + /// The value of the result + /// The number of bytes for the used to store the value's length and value + /// Whether or not the value is null + public FileStreamReadResult(T value, int totalLength, bool isNull) + { + Value = value; + TotalLength = totalLength; + IsNull = isNull; + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs new file mode 100644 index 00000000..afe616f3 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs @@ -0,0 +1,282 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Diagnostics; +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Wrapper for a file stream, providing simplified creation, deletion, read, and write + /// functionality. + /// + public class FileStreamWrapper : IFileStreamWrapper + { + #region Member Variables + + private byte[] buffer; + private int bufferDataSize; + private FileStream fileStream; + private long startOffset; + private long currentOffset; + + #endregion + + /// + /// Constructs a new FileStreamWrapper and initializes its state. + /// + public FileStreamWrapper() + { + // Initialize the internal state + bufferDataSize = 0; + startOffset = 0; + currentOffset = 0; + } + + #region IFileStreamWrapper Implementation + + /// + /// Initializes the wrapper by creating the internal buffer and opening the requested file. + /// If the file does not already exist, it will be created. + /// + /// Name of the file to open/create + /// The length of the internal buffer + /// + /// Whether or not the wrapper will be used for reading. 
If true, any calls to a + /// method that writes will cause an InvalidOperationException + /// + public void Init(string fileName, int bufferLength, FileAccess accessMethod) + { + // Sanity check for valid buffer length, fileName, and accessMethod + if (bufferLength <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferLength), "Buffer length must be a positive value"); + } + if (string.IsNullOrWhiteSpace(fileName)) + { + throw new ArgumentNullException(nameof(fileName), "File name cannot be null or whitespace"); + } + if (accessMethod == FileAccess.Write) + { + throw new ArgumentException("Access method cannot be write-only", nameof(fileName)); + } + + // Setup the buffer + buffer = new byte[bufferLength]; + + // Open the requested file for reading/writing, creating one if it doesn't exist + fileStream = new FileStream(fileName, FileMode.OpenOrCreate, accessMethod, FileShare.ReadWrite, + bufferLength, false /*don't use asyncio*/); + + // make file hidden + FileInfo fileInfo = new FileInfo(fileName); + fileInfo.Attributes |= FileAttributes.Hidden; + } + + /// + /// Reads data into a buffer from the current offset into the file + /// + /// The buffer to output the read data to + /// The number of bytes to read into the buffer + /// The number of bytes read + public int ReadData(byte[] buf, int bytes) + { + return ReadData(buf, bytes, currentOffset); + } + + /// + /// Reads data into a buffer from the specified offset into the file + /// + /// The buffer to output the read data to + /// The number of bytes to read into the buffer + /// The offset into the file to start reading bytes from + /// The number of bytes read + public int ReadData(byte[] buf, int bytes, long offset) + { + // Make sure that we're initialized before performing operations + if (buffer == null) + { + throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations"); + } + + MoveTo(offset); + + int bytesCopied = 0; + while (bytesCopied < bytes) + { + int bufferOffset, bytesToCopy; + GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy); + Buffer.BlockCopy(buffer, bufferOffset, buf, bytesCopied, bytesToCopy); + bytesCopied += bytesToCopy; + + if (bytesCopied < bytes && // did not get all the bytes yet + bufferDataSize == buffer.Length) // since current data buffer is full we should continue reading the file + { + // move forward one full length of the buffer + MoveTo(startOffset + buffer.Length); + } + else + { + // copied all the bytes requested or possible, adjust the current buffer pointer + currentOffset += bytesToCopy; + break; + } + } + return bytesCopied; + } + + /// + /// Writes data to the underlying filestream, with buffering. 
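Putting the pieces together, a round trip through the wrapper looks roughly like the sketch below (temp file and payload are arbitrary; note that Init deliberately rejects FileAccess.Write, so a writer opens the file as ReadWrite):

    using (IFileStreamWrapper wrapper = new FileStreamWrapper())
    {
        wrapper.Init(Path.GetTempFileName(), 8192, FileAccess.ReadWrite);

        byte[] payload = { 0x01, 0x02, 0x03 };
        wrapper.WriteData(payload, payload.Length);
        wrapper.Flush();                    // push buffered bytes to disk

        byte[] readBack = new byte[3];
        wrapper.ReadData(readBack, 3, 0);   // re-read them from the start of the file
    }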
+ /// + /// The buffer of bytes to write to the filestream + /// The number of bytes to write + /// The number of bytes written + public int WriteData(byte[] buf, int bytes) + { + // Make sure that we're initialized before performing operations + if (buffer == null) + { + throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations"); + } + if (!fileStream.CanWrite) + { + throw new InvalidOperationException("This FileStreamWrapper canot be used for writing"); + } + + int bytesCopied = 0; + while (bytesCopied < bytes) + { + int bufferOffset, bytesToCopy; + GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy); + Buffer.BlockCopy(buf, bytesCopied, buffer, bufferOffset, bytesToCopy); + bytesCopied += bytesToCopy; + + // adjust the current buffer pointer + currentOffset += bytesToCopy; + + if (bytesCopied < bytes) // did not get all the bytes yet + { + Debug.Assert((int)(currentOffset - startOffset) == buffer.Length); + // flush buffer + Flush(); + } + } + Debug.Assert(bytesCopied == bytes); + return bytesCopied; + } + + /// + /// Flushes the internal buffer to the filestream + /// + public void Flush() + { + // Make sure that we're initialized before performing operations + if (buffer == null) + { + throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations"); + } + if (!fileStream.CanWrite) + { + throw new InvalidOperationException("This FileStreamWrapper cannot be used for writing"); + } + + // Make sure we are at the right place in the file + Debug.Assert(fileStream.Position == startOffset); + + int bytesToWrite = (int)(currentOffset - startOffset); + fileStream.Write(buffer, 0, bytesToWrite); + startOffset += bytesToWrite; + fileStream.Flush(); + + Debug.Assert(startOffset == currentOffset); + } + + /// + /// Deletes the given file (ideally, created with this wrapper) from the filesystem + /// + /// The path to the file to delete + public static void DeleteFile(string fileName) + { + File.Delete(fileName); + } + + #endregion + + /// + /// Perform calculations to determine how many bytes to copy and what the new buffer offset + /// will be for copying. 
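A worked example of this chunking math, with illustrative numbers:

    // buffer.Length = 8192, startOffset = 0, currentOffset = 8000, request = 500 bytes
    //
    // Pass 1: bufferOffset = 8000, bytesToCopy = min(500, 8192 - 8000) = 192
    //         buffer exhausted, so MoveTo(startOffset + 8192) refills it from file offset 8192
    // Pass 2: bufferOffset = 0,    bytesToCopy = 500 - 192 = 308
    //         all 500 bytes delivered; currentOffset advances past the copied bytes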
+ /// + /// Number of bytes requested to copy + /// Number of bytes copied so far + /// New offset to start copying from/to + /// Number of bytes to copy in this iteration + private void GetByteCounts(int bytes, int bytesCopied, out int bufferOffset, out int bytesToCopy) + { + bufferOffset = (int) (currentOffset - startOffset); + bytesToCopy = bytes - bytesCopied; + if (bytesToCopy > buffer.Length - bufferOffset) + { + bytesToCopy = buffer.Length - bufferOffset; + } + } + + /// + /// Moves the internal buffer to the specified offset into the file + /// + /// Offset into the file to move to + private void MoveTo(long offset) + { + if (buffer.Length > bufferDataSize || // buffer is not completely filled + offset < startOffset || // before current buffer start + offset >= (startOffset + buffer.Length)) // beyond current buffer end + { + // init the offset + startOffset = offset; + + // position file pointer + fileStream.Seek(startOffset, SeekOrigin.Begin); + + // fill in the buffer + bufferDataSize = fileStream.Read(buffer, 0, buffer.Length); + } + // make sure to record where we are + currentOffset = offset; + } + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing && fileStream != null) + { + if(fileStream.CanWrite) { Flush(); } + fileStream.Dispose(); + } + + disposed = true; + } + + ~FileStreamWrapper() + { + Dispose(false); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs new file mode 100644 index 00000000..6cb50095 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs @@ -0,0 +1,22 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a factory that creates filesystem readers/writers + /// + public interface IFileStreamFactory + { + string CreateFile(); + + IFileStreamReader GetReader(string fileName); + + IFileStreamWriter GetWriter(string fileName, int maxCharsToStore, int maxXmlCharsToStore); + + void DisposeFile(string fileName); + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs new file mode 100644 index 00000000..ea5584f1 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs @@ -0,0 +1,35 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+// + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a object that reads from the filesystem + /// + public interface IFileStreamReader : IDisposable + { + object[] ReadRow(long offset, IEnumerable columns); + FileStreamReadResult ReadInt16(long i64Offset); + FileStreamReadResult ReadInt32(long i64Offset); + FileStreamReadResult ReadInt64(long i64Offset); + FileStreamReadResult ReadByte(long i64Offset); + FileStreamReadResult ReadChar(long i64Offset); + FileStreamReadResult ReadBoolean(long i64Offset); + FileStreamReadResult ReadSingle(long i64Offset); + FileStreamReadResult ReadDouble(long i64Offset); + FileStreamReadResult ReadSqlDecimal(long i64Offset); + FileStreamReadResult ReadDecimal(long i64Offset); + FileStreamReadResult ReadDateTime(long i64Offset); + FileStreamReadResult ReadTimeSpan(long i64Offset); + FileStreamReadResult ReadString(long i64Offset); + FileStreamReadResult ReadBytes(long i64Offset); + FileStreamReadResult ReadDateTimeOffset(long i64Offset); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs new file mode 100644 index 00000000..38c283c5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs @@ -0,0 +1,22 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a wrapper around a filesystem reader/writer, mainly for unit testing purposes + /// + public interface IFileStreamWrapper : IDisposable + { + void Init(string fileName, int bufferSize, FileAccess fileAccessMode); + int ReadData(byte[] buffer, int bytes); + int ReadData(byte[] buffer, int bytes, long fileOffset); + int WriteData(byte[] buffer, int bytes); + void Flush(); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs new file mode 100644 index 00000000..968701ed --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs @@ -0,0 +1,35 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+// + +using System; +using System.Data.SqlTypes; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Interface for a object that writes to a filesystem wrapper + /// + public interface IFileStreamWriter : IDisposable + { + int WriteRow(StorageDataReader dataReader); + int WriteNull(); + int WriteInt16(short val); + int WriteInt32(int val); + int WriteInt64(long val); + int WriteByte(byte val); + int WriteChar(char val); + int WriteBoolean(bool val); + int WriteSingle(float val); + int WriteDouble(double val); + int WriteDecimal(decimal val); + int WriteSqlDecimal(SqlDecimal val); + int WriteDateTime(DateTime val); + int WriteDateTimeOffset(DateTimeOffset dtoVal); + int WriteTimeSpan(TimeSpan val); + int WriteString(string val); + int WriteBytes(byte[] bytes, int length); + void FlushBuffer(); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs new file mode 100644 index 00000000..c06a13ac --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs @@ -0,0 +1,64 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Factory that creates file reader/writers that process rows in an internal, non-human readable file format + /// + public class ServiceBufferFileStreamFactory : IFileStreamFactory + { + /// + /// Creates a new temporary file + /// + /// The name of the temporary file + public string CreateFile() + { + return Path.GetTempFileName(); + } + + /// + /// Creates a new for reading values back from + /// an SSMS formatted buffer file + /// + /// The file to read values from + /// A + public IFileStreamReader GetReader(string fileName) + { + return new ServiceBufferFileStreamReader(new FileStreamWrapper(), fileName); + } + + /// + /// Creates a new for writing values out to an + /// SSMS formatted buffer file + /// + /// The file to write values to + /// The maximum number of characters to store from long text fields + /// The maximum number of characters to store from xml fields + /// A + public IFileStreamWriter GetWriter(string fileName, int maxCharsToStore, int maxXmlCharsToStore) + { + return new ServiceBufferFileStreamWriter(new FileStreamWrapper(), fileName, maxCharsToStore, maxXmlCharsToStore); + } + + /// + /// Disposes of a file created via this factory + /// + /// The file to dispose of + public void DisposeFile(string fileName) + { + try + { + FileStreamWrapper.DeleteFile(fileName); + } + catch + { + // If we have problems deleting the file from a temp location, we don't really care + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs new file mode 100644 index 00000000..0cfc2466 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs @@ -0,0 +1,889 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+// + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.IO; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Reader for SSMS formatted file streams + /// + public class ServiceBufferFileStreamReader : IFileStreamReader + { + // Most of this code is based on code from the Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage + // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\FileStreamReader.cs + + private const int DefaultBufferSize = 8192; + + #region Member Variables + + private byte[] buffer; + + private readonly IFileStreamWrapper fileStream; + + #endregion + + /// + /// Constructs a new ServiceBufferFileStreamReader and initializes its state + /// + /// The filestream wrapper to read from + /// The name of the file to read from + public ServiceBufferFileStreamReader(IFileStreamWrapper fileWrapper, string fileName) + { + // Open file for reading/writing + fileStream = fileWrapper; + fileStream.Init(fileName, DefaultBufferSize, FileAccess.Read); + + // Create internal buffer + buffer = new byte[DefaultBufferSize]; + } + + #region IFileStreamStorage Implementation + + /// + /// Reads a row from the file, based on the columns provided + /// + /// Offset into the file where the row starts + /// The columns that were encoded + /// The objects from the row + public object[] ReadRow(long fileOffset, IEnumerable columns) + { + // Initialize for the loop + long currentFileOffset = fileOffset; + List results = new List(); + + // Iterate over the columns + foreach (DbColumnWrapper column in columns) + { + // We will pivot based on the type of the column + Type colType; + if (column.IsSqlVariant) + { + // For SQL Variant columns, the type is written first in string format + FileStreamReadResult sqlVariantTypeResult = ReadString(currentFileOffset); + currentFileOffset += sqlVariantTypeResult.TotalLength; + + // If the typename is null, then the whole value is null + if (sqlVariantTypeResult.IsNull) + { + results.Add(null); + continue; + } + + // The typename is stored in the string + colType = Type.GetType(sqlVariantTypeResult.Value); + + // Workaround .NET bug, see sqlbu# 440643 and vswhidbey# 599834 + // TODO: Is this workaround necessary for .NET Core? + if (colType == null && sqlVariantTypeResult.Value == "System.Data.SqlTypes.SqlSingle") + { + colType = typeof(SqlSingle); + } + } + else + { + colType = column.DataType; + } + + if (colType == typeof(string)) + { + // String - most frequently used data type + FileStreamReadResult result = ReadString(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : result.Value); + } + else if (colType == typeof(SqlString)) + { + // SqlString + FileStreamReadResult result = ReadString(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? 
null : (SqlString) result.Value); + } + else if (colType == typeof(short)) + { + // Int16 + FileStreamReadResult result = ReadInt16(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlInt16)) + { + // SqlInt16 + FileStreamReadResult result = ReadInt16(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlInt16)result.Value); + } + } + else if (colType == typeof(int)) + { + // Int32 + FileStreamReadResult result = ReadInt32(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlInt32)) + { + // SqlInt32 + FileStreamReadResult result = ReadInt32(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlInt32)result.Value); + } + } + else if (colType == typeof(long)) + { + // Int64 + FileStreamReadResult result = ReadInt64(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlInt64)) + { + // SqlInt64 + FileStreamReadResult result = ReadInt64(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlInt64)result.Value); + } + } + else if (colType == typeof(byte)) + { + // byte + FileStreamReadResult result = ReadByte(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlByte)) + { + // SqlByte + FileStreamReadResult result = ReadByte(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlByte)result.Value); + } + } + else if (colType == typeof(char)) + { + // Char + FileStreamReadResult result = ReadChar(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(bool)) + { + // Bool + FileStreamReadResult result = ReadBoolean(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlBoolean)) + { + // SqlBoolean + FileStreamReadResult result = ReadBoolean(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlBoolean)result.Value); + } + } + else if (colType == typeof(double)) + { + // double + FileStreamReadResult result = ReadDouble(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlDouble)) + { + // SqlByte + FileStreamReadResult result = ReadDouble(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlDouble)result.Value); + } + } + else if (colType == typeof(float)) + { + // float + FileStreamReadResult result = 
ReadSingle(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlSingle)) + { + // SqlSingle + FileStreamReadResult result = ReadSingle(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlSingle)result.Value); + } + } + else if (colType == typeof(decimal)) + { + // Decimal + FileStreamReadResult result = ReadDecimal(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlDecimal)) + { + // SqlDecimal + FileStreamReadResult result = ReadSqlDecimal(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(DateTime)) + { + // DateTime + FileStreamReadResult result = ReadDateTime(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlDateTime)) + { + // SqlDateTime + FileStreamReadResult result = ReadDateTime(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add((SqlDateTime)result.Value); + } + } + else if (colType == typeof(DateTimeOffset)) + { + // DateTimeOffset + FileStreamReadResult result = ReadDateTimeOffset(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(TimeSpan)) + { + // TimeSpan + FileStreamReadResult result = ReadTimeSpan(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(byte[])) + { + // Byte Array + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull || (column.IsUdt && result.Value.Length == 0)) + { + results.Add(null); + } + else + { + results.Add(result.Value); + } + } + else if (colType == typeof(SqlBytes)) + { + // SqlBytes + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : new SqlBytes(result.Value)); + } + else if (colType == typeof(SqlBinary)) + { + // SqlBinary + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? 
null : new SqlBinary(result.Value)); + } + else if (colType == typeof(SqlGuid)) + { + // SqlGuid + FileStreamReadResult result = ReadBytes(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(new SqlGuid(result.Value)); + } + } + else if (colType == typeof(SqlMoney)) + { + // SqlMoney + FileStreamReadResult result = ReadDecimal(currentFileOffset); + currentFileOffset += result.TotalLength; + if (result.IsNull) + { + results.Add(null); + } + else + { + results.Add(new SqlMoney(result.Value)); + } + } + else + { + // Treat everything else as a string + FileStreamReadResult result = ReadString(currentFileOffset); + currentFileOffset += result.TotalLength; + results.Add(result.IsNull ? null : result.Value); + } + } + + return results.ToArray(); + } + + /// + /// Reads a short from the file at the offset provided + /// + /// Offset into the file to read the short from + /// A short + public FileStreamReadResult ReadInt16(long fileOffset) + { + + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 2, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + short val = default(short); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToInt16(buffer, 0); + } + + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a int from the file at the offset provided + /// + /// Offset into the file to read the int from + /// An int + public FileStreamReadResult ReadInt32(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 4, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + int val = default(int); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToInt32(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a long from the file at the offset provided + /// + /// Offset into the file to read the long from + /// A long + public FileStreamReadResult ReadInt64(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 8, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + long val = default(long); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToInt64(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a byte from the file at the offset provided + /// + /// Offset into the file to read the byte from + /// A byte + public FileStreamReadResult ReadByte(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 1, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + byte val = default(byte); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = buffer[0]; + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a char from the file at the offset provided + /// + /// Offset into the file to read the char from + /// A char + public FileStreamReadResult ReadChar(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 2, "Invalid data length"); + + bool isNull = length.ValueLength == 
0; + char val = default(char); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToChar(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a bool from the file at the offset provided + /// + /// Offset into the file to read the bool from + /// A bool + public FileStreamReadResult ReadBoolean(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 1, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + bool val = default(bool); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = buffer[0] == 0x01; + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a single from the file at the offset provided + /// + /// Offset into the file to read the single from + /// A single + public FileStreamReadResult ReadSingle(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 4, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + float val = default(float); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToSingle(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a double from the file at the offset provided + /// + /// Offset into the file to read the double from + /// A double + public FileStreamReadResult ReadDouble(long fileOffset) + { + LengthResult length = ReadLength(fileOffset); + Debug.Assert(length.ValueLength == 0 || length.ValueLength == 8, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + double val = default(double); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + val = BitConverter.ToDouble(buffer, 0); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a SqlDecimal from the file at the offset provided + /// + /// Offset into the file to read the SqlDecimal from + /// A SqlDecimal + public FileStreamReadResult ReadSqlDecimal(long offset) + { + LengthResult length = ReadLength(offset); + Debug.Assert(length.ValueLength == 0 || (length.ValueLength - 3)%4 == 0, + string.Format("Invalid data length: {0}", length.ValueLength)); + + bool isNull = length.ValueLength == 0; + SqlDecimal val = default(SqlDecimal); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + + int[] arrInt32 = new int[(length.ValueLength - 3)/4]; + Buffer.BlockCopy(buffer, 3, arrInt32, 0, length.ValueLength - 3); + val = new SqlDecimal(buffer[0], buffer[1], 1 == buffer[2], arrInt32); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a decimal from the file at the offset provided + /// + /// Offset into the file to read the decimal from + /// A decimal + public FileStreamReadResult ReadDecimal(long offset) + { + LengthResult length = ReadLength(offset); + Debug.Assert(length.ValueLength%4 == 0, "Invalid data length"); + + bool isNull = length.ValueLength == 0; + decimal val = default(decimal); + if (!isNull) + { + fileStream.ReadData(buffer, length.ValueLength); + + int[] arrInt32 = new int[length.ValueLength/4]; + Buffer.BlockCopy(buffer, 0, arrInt32, 0, length.ValueLength); + val = new decimal(arrInt32); + } + return new FileStreamReadResult(val, length.TotalLength, isNull); + } + + /// + /// Reads a DateTime from the file at the 
offset provided + /// + /// Offset into the file to read the DateTime from + /// A DateTime + public FileStreamReadResult ReadDateTime(long offset) + { + FileStreamReadResult ticks = ReadInt64(offset); + DateTime val = default(DateTime); + if (!ticks.IsNull) + { + val = new DateTime(ticks.Value); + } + return new FileStreamReadResult(val, ticks.TotalLength, ticks.IsNull); + } + + /// + /// Reads a DateTimeOffset from the file at the offset provided + /// + /// Offset into the file to read the DateTimeOffset from + /// A DateTimeOffset + public FileStreamReadResult ReadDateTimeOffset(long offset) + { + // DateTimeOffset is represented by DateTime.Ticks followed by TimeSpan.Ticks + // both as Int64 values + + // read the DateTime ticks + DateTimeOffset val = default(DateTimeOffset); + FileStreamReadResult dateTimeTicks = ReadInt64(offset); + int totalLength = dateTimeTicks.TotalLength; + if (dateTimeTicks.TotalLength > 0 && !dateTimeTicks.IsNull) + { + // read the TimeSpan ticks + FileStreamReadResult timeSpanTicks = ReadInt64(offset + dateTimeTicks.TotalLength); + Debug.Assert(!timeSpanTicks.IsNull, "TimeSpan ticks cannot be null if DateTime ticks are not null!"); + + totalLength += timeSpanTicks.TotalLength; + + // build the DateTimeOffset + val = new DateTimeOffset(new DateTime(dateTimeTicks.Value), new TimeSpan(timeSpanTicks.Value)); + } + return new FileStreamReadResult(val, totalLength, dateTimeTicks.IsNull); + } + + /// + /// Reads a TimeSpan from the file at the offset provided + /// + /// Offset into the file to read the TimeSpan from + /// A TimeSpan + public FileStreamReadResult ReadTimeSpan(long offset) + { + FileStreamReadResult timeSpanTicks = ReadInt64(offset); + TimeSpan val = default(TimeSpan); + if (!timeSpanTicks.IsNull) + { + val = new TimeSpan(timeSpanTicks.Value); + } + return new FileStreamReadResult(val, timeSpanTicks.TotalLength, timeSpanTicks.IsNull); + } + + /// + /// Reads a string from the file at the offset provided + /// + /// Offset into the file to read the string from + /// A string + public FileStreamReadResult ReadString(long offset) + { + LengthResult fieldLength = ReadLength(offset); + Debug.Assert(fieldLength.ValueLength%2 == 0, "Invalid data length"); + + if (fieldLength.ValueLength == 0) // there is no data + { + // If the total length is 5 (5 bytes for length, 0 for value), then the string is empty + // Otherwise, the string is null + bool isNull = fieldLength.TotalLength != 5; + return new FileStreamReadResult(isNull ? null : string.Empty, + fieldLength.TotalLength, isNull); + } + + // positive length + AssureBufferLength(fieldLength.ValueLength); + fileStream.ReadData(buffer, fieldLength.ValueLength); + return new FileStreamReadResult(Encoding.Unicode.GetString(buffer, 0, fieldLength.ValueLength), fieldLength.TotalLength, false); + } + + /// + /// Reads bytes from the file at the offset provided + /// + /// Offset into the file to read the bytes from + /// A byte array + public FileStreamReadResult ReadBytes(long offset) + { + LengthResult fieldLength = ReadLength(offset); + + if (fieldLength.ValueLength == 0) + { + // If the total length is 5 (5 bytes for length, 0 for value), then the byte array is 0x + // Otherwise, the byte array is null + bool isNull = fieldLength.TotalLength != 5; + return new FileStreamReadResult(isNull ? 
null : new byte[0], + fieldLength.TotalLength, isNull); + } + + // positive length + byte[] val = new byte[fieldLength.ValueLength]; + fileStream.ReadData(val, fieldLength.ValueLength); + return new FileStreamReadResult(val, fieldLength.TotalLength, false); + } + + /// + /// Reads the length of a field at the specified offset in the file + /// + /// Offset into the file to read the field length from + /// A LengthResult + internal LengthResult ReadLength(long offset) + { + // read in length information + int lengthValue; + int lengthLength = fileStream.ReadData(buffer, 1, offset); + if (buffer[0] != 0xFF) + { + // one byte is enough + lengthValue = Convert.ToInt32(buffer[0]); + } + else + { + // read in next 4 bytes + lengthLength += fileStream.ReadData(buffer, 4); + + // reconstruct the length + lengthValue = BitConverter.ToInt32(buffer, 0); + } + + return new LengthResult {LengthLength = lengthLength, ValueLength = lengthValue}; + } + + #endregion + + /// + /// Internal struct used for representing the length of a field from the file + /// + internal struct LengthResult + { + /// + /// How many bytes the length takes up + /// + public int LengthLength { get; set; } + + /// + /// How many bytes the value takes up + /// + public int ValueLength { get; set; } + + /// + /// + + /// + public int TotalLength + { + get { return LengthLength + ValueLength; } + } + } + + /// + /// Creates a new buffer that is of the specified length if the buffer is not already + /// at least as long as specified. + /// + /// The minimum buffer size + private void AssureBufferLength(int newBufferLength) + { + if (buffer.Length < newBufferLength) + { + buffer = new byte[newBufferLength]; + } + } + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + fileStream.Dispose(); + } + + disposed = true; + } + + ~ServiceBufferFileStreamReader() + { + Dispose(false); + } + + #endregion + + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs new file mode 100644 index 00000000..d0a1c2a9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamWriter.cs @@ -0,0 +1,749 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+// + +using System; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Writer for SSMS formatted file streams + /// + public class ServiceBufferFileStreamWriter : IFileStreamWriter + { + // Most of this code is based on code from the Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage + // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\FileStreamWriter.cs + + #region Properties + + public const int DefaultBufferLength = 8192; + + private int MaxCharsToStore { get; set; } + private int MaxXmlCharsToStore { get; set; } + + private IFileStreamWrapper FileStream { get; set; } + private byte[] byteBuffer; + private readonly short[] shortBuffer; + private readonly int[] intBuffer; + private readonly long[] longBuffer; + private readonly char[] charBuffer; + private readonly double[] doubleBuffer; + private readonly float[] floatBuffer; + + #endregion + + /// + /// Constructs a new writer + /// + /// The file wrapper to use as the underlying file stream + /// Name of the file to write to + /// Maximum number of characters to store for long text fields + /// Maximum number of characters to store for XML fields + public ServiceBufferFileStreamWriter(IFileStreamWrapper fileWrapper, string fileName, int maxCharsToStore, int maxXmlCharsToStore) + { + // open file for reading/writing + FileStream = fileWrapper; + FileStream.Init(fileName, DefaultBufferLength, FileAccess.ReadWrite); + + // create internal buffer + byteBuffer = new byte[DefaultBufferLength]; + + // Create internal buffers for blockcopy of contents to byte array + // Note: We create them now to avoid the overhead of creating a new array for every write call + shortBuffer = new short[1]; + intBuffer = new int[1]; + longBuffer = new long[1]; + charBuffer = new char[1]; + doubleBuffer = new double[1]; + floatBuffer = new float[1]; + + // Store max chars to store + MaxCharsToStore = maxCharsToStore; + MaxXmlCharsToStore = maxXmlCharsToStore; + } + + #region IFileStreamWriter Implementation + + /// + /// Writes an entire row to the file stream + /// + /// A primed reader + /// Number of bytes used to write the row + public int WriteRow(StorageDataReader reader) + { + // Determine if we have any long fields + bool hasLongFields = reader.Columns.Any(column => column.IsLong); + + object[] values = new object[reader.Columns.Length]; + int rowBytes = 0; + if (!hasLongFields) + { + // get all record values in one shot if there are no extra long fields + reader.GetValues(values); + } + + // Loop over all the columns and write the values to the temp file + for (int i = 0; i < reader.Columns.Length; i++) + { + DbColumnWrapper ci = reader.Columns[i]; + if (hasLongFields) + { + if (reader.IsDBNull(i)) + { + // Need special case for DBNull because + // reader.GetValue doesn't return DBNull in case of SqlXml and CLR type + values[i] = DBNull.Value; + } + else + { + if (!ci.IsLong) + { + // not a long field + values[i] = reader.GetValue(i); + } + else + { + // this is a long field + if (ci.IsBytes) + { + values[i] = reader.GetBytesWithMaxCapacity(i, MaxCharsToStore); + } + else if (ci.IsChars) + { + Debug.Assert(MaxCharsToStore > 0); + values[i] = reader.GetCharsWithMaxCapacity(i, + ci.IsXml ? 
MaxXmlCharsToStore : MaxCharsToStore); + } + else if (ci.IsXml) + { + Debug.Assert(MaxXmlCharsToStore > 0); + values[i] = reader.GetXmlWithMaxCapacity(i, MaxXmlCharsToStore); + } + else + { + // we should never get here + Debug.Assert(false); + } + } + } + } + + Type tVal = values[i].GetType(); // get true type of the object + + if (tVal == typeof(DBNull)) + { + rowBytes += WriteNull(); + } + else + { + if (ci.IsSqlVariant) + { + // serialize type information as a string before the value + string val = tVal.ToString(); + rowBytes += WriteString(val); + } + + if (tVal == typeof(string)) + { + // String - most frequently used data type + string val = (string)values[i]; + rowBytes += WriteString(val); + } + else if (tVal == typeof(SqlString)) + { + // SqlString + SqlString val = (SqlString)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteString(val.Value); + } + } + else if (tVal == typeof(short)) + { + // Int16 + short val = (short)values[i]; + rowBytes += WriteInt16(val); + } + else if (tVal == typeof(SqlInt16)) + { + // SqlInt16 + SqlInt16 val = (SqlInt16)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteInt16(val.Value); + } + } + else if (tVal == typeof(int)) + { + // Int32 + int val = (int)values[i]; + rowBytes += WriteInt32(val); + } + else if (tVal == typeof(SqlInt32)) + { + // SqlInt32 + SqlInt32 val = (SqlInt32)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteInt32(val.Value); + } + } + else if (tVal == typeof(long)) + { + // Int64 + long val = (long)values[i]; + rowBytes += WriteInt64(val); + } + else if (tVal == typeof(SqlInt64)) + { + // SqlInt64 + SqlInt64 val = (SqlInt64)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteInt64(val.Value); + } + } + else if (tVal == typeof(byte)) + { + // Byte + byte val = (byte)values[i]; + rowBytes += WriteByte(val); + } + else if (tVal == typeof(SqlByte)) + { + // SqlByte + SqlByte val = (SqlByte)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteByte(val.Value); + } + } + else if (tVal == typeof(char)) + { + // Char + char val = (char)values[i]; + rowBytes += WriteChar(val); + } + else if (tVal == typeof(bool)) + { + // Boolean + bool val = (bool)values[i]; + rowBytes += WriteBoolean(val); + } + else if (tVal == typeof(SqlBoolean)) + { + // SqlBoolean + SqlBoolean val = (SqlBoolean)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteBoolean(val.Value); + } + } + else if (tVal == typeof(double)) + { + // Double + double val = (double)values[i]; + rowBytes += WriteDouble(val); + } + else if (tVal == typeof(SqlDouble)) + { + // SqlDouble + SqlDouble val = (SqlDouble)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteDouble(val.Value); + } + } + else if (tVal == typeof(SqlSingle)) + { + // SqlSingle + SqlSingle val = (SqlSingle)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteSingle(val.Value); + } + } + else if (tVal == typeof(decimal)) + { + // Decimal + decimal val = (decimal)values[i]; + rowBytes += WriteDecimal(val); + } + else if (tVal == typeof(SqlDecimal)) + { + // SqlDecimal + SqlDecimal val = (SqlDecimal)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteSqlDecimal(val); + } + } + else if (tVal == typeof(DateTime)) + { + // DateTime + DateTime val = 
(DateTime)values[i]; + rowBytes += WriteDateTime(val); + } + else if (tVal == typeof(DateTimeOffset)) + { + // DateTimeOffset + DateTimeOffset val = (DateTimeOffset)values[i]; + rowBytes += WriteDateTimeOffset(val); + } + else if (tVal == typeof(SqlDateTime)) + { + // SqlDateTime + SqlDateTime val = (SqlDateTime)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteDateTime(val.Value); + } + } + else if (tVal == typeof(TimeSpan)) + { + // TimeSpan + TimeSpan val = (TimeSpan)values[i]; + rowBytes += WriteTimeSpan(val); + } + else if (tVal == typeof(byte[])) + { + // Bytes + byte[] val = (byte[])values[i]; + rowBytes += WriteBytes(val, val.Length); + } + else if (tVal == typeof(SqlBytes)) + { + // SqlBytes + SqlBytes val = (SqlBytes)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteBytes(val.Value, val.Value.Length); + } + } + else if (tVal == typeof(SqlBinary)) + { + // SqlBinary + SqlBinary val = (SqlBinary)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteBytes(val.Value, val.Value.Length); + } + } + else if (tVal == typeof(SqlGuid)) + { + // SqlGuid + SqlGuid val = (SqlGuid)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + byte[] bytesVal = val.ToByteArray(); + rowBytes += WriteBytes(bytesVal, bytesVal.Length); + } + } + else if (tVal == typeof(SqlMoney)) + { + // SqlMoney + SqlMoney val = (SqlMoney)values[i]; + if (val.IsNull) + { + rowBytes += WriteNull(); + } + else + { + rowBytes += WriteDecimal(val.Value); + } + } + else + { + // treat everything else as string + string val = values[i].ToString(); + rowBytes += WriteString(val); + } + } + } + + // Flush the buffer after every row + FlushBuffer(); + return rowBytes; + } + + /// + /// Writes null to the file as one 0x00 byte + /// + /// Number of bytes used to store the null + public int WriteNull() + { + byteBuffer[0] = 0x00; + return FileStream.WriteData(byteBuffer, 1); + } + + /// + /// Writes a short to the file + /// + /// Number of bytes used to store the short + public int WriteInt16(short val) + { + byteBuffer[0] = 0x02; // length + shortBuffer[0] = val; + Buffer.BlockCopy(shortBuffer, 0, byteBuffer, 1, 2); + return FileStream.WriteData(byteBuffer, 3); + } + + /// + /// Writes a int to the file + /// + /// Number of bytes used to store the int + public int WriteInt32(int val) + { + byteBuffer[0] = 0x04; // length + intBuffer[0] = val; + Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); + return FileStream.WriteData(byteBuffer, 5); + } + + /// + /// Writes a long to the file + /// + /// Number of bytes used to store the long + public int WriteInt64(long val) + { + byteBuffer[0] = 0x08; // length + longBuffer[0] = val; + Buffer.BlockCopy(longBuffer, 0, byteBuffer, 1, 8); + return FileStream.WriteData(byteBuffer, 9); + } + + /// + /// Writes a char to the file + /// + /// Number of bytes used to store the char + public int WriteChar(char val) + { + byteBuffer[0] = 0x02; // length + charBuffer[0] = val; + Buffer.BlockCopy(charBuffer, 0, byteBuffer, 1, 2); + return FileStream.WriteData(byteBuffer, 3); + } + + /// + /// Writes a bool to the file + /// + /// Number of bytes used to store the bool + public int WriteBoolean(bool val) + { + byteBuffer[0] = 0x01; // length + byteBuffer[1] = (byte) (val ? 
0x01 : 0x00); + return FileStream.WriteData(byteBuffer, 2); + } + + /// + /// Writes a byte to the file + /// + /// Number of bytes used to store the byte + public int WriteByte(byte val) + { + byteBuffer[0] = 0x01; // length + byteBuffer[1] = val; + return FileStream.WriteData(byteBuffer, 2); + } + + /// + /// Writes a float to the file + /// + /// Number of bytes used to store the float + public int WriteSingle(float val) + { + byteBuffer[0] = 0x04; // length + floatBuffer[0] = val; + Buffer.BlockCopy(floatBuffer, 0, byteBuffer, 1, 4); + return FileStream.WriteData(byteBuffer, 5); + } + + /// + /// Writes a double to the file + /// + /// Number of bytes used to store the double + public int WriteDouble(double val) + { + byteBuffer[0] = 0x08; // length + doubleBuffer[0] = val; + Buffer.BlockCopy(doubleBuffer, 0, byteBuffer, 1, 8); + return FileStream.WriteData(byteBuffer, 9); + } + + /// + /// Writes a SqlDecimal to the file + /// + /// Number of bytes used to store the SqlDecimal + public int WriteSqlDecimal(SqlDecimal val) + { + int[] arrInt32 = val.Data; + int iLen = 3 + (arrInt32.Length * 4); + int iTotalLen = WriteLength(iLen); // length + + // precision + byteBuffer[0] = val.Precision; + + // scale + byteBuffer[1] = val.Scale; + + // positive + byteBuffer[2] = (byte)(val.IsPositive ? 0x01 : 0x00); + + // data value + Buffer.BlockCopy(arrInt32, 0, byteBuffer, 3, iLen - 3); + iTotalLen += FileStream.WriteData(byteBuffer, iLen); + return iTotalLen; // len+data + } + + /// + /// Writes a decimal to the file + /// + /// Number of bytes used to store the decimal + public int WriteDecimal(decimal val) + { + int[] arrInt32 = decimal.GetBits(val); + + int iLen = arrInt32.Length * 4; + int iTotalLen = WriteLength(iLen); // length + + Buffer.BlockCopy(arrInt32, 0, byteBuffer, 0, iLen); + iTotalLen += FileStream.WriteData(byteBuffer, iLen); + + return iTotalLen; // len+data + } + + /// + /// Writes a DateTime to the file + /// + /// Number of bytes used to store the DateTime + public int WriteDateTime(DateTime dtVal) + { + return WriteInt64(dtVal.Ticks); + } + + /// + /// Writes a DateTimeOffset to the file + /// + /// Number of bytes used to store the DateTimeOffset + public int WriteDateTimeOffset(DateTimeOffset dtoVal) + { + // DateTimeOffset gets written as a DateTime + TimeOffset + // both represented as 'Ticks' written as Int64's + return WriteInt64(dtoVal.Ticks) + WriteInt64(dtoVal.Offset.Ticks); + } + + /// + /// Writes a TimeSpan to the file + /// + /// Number of bytes used to store the TimeSpan + public int WriteTimeSpan(TimeSpan timeSpan) + { + return WriteInt64(timeSpan.Ticks); + } + + /// + /// Writes a string to the file + /// + /// Number of bytes used to store the string + public int WriteString(string sVal) + { + if (sVal == null) + { + throw new ArgumentNullException(nameof(sVal), "String to store must be non-null."); + } + + int iTotalLen; + if (0 == sVal.Length) // special case of 0 length string + { + const int iLen = 5; + + AssureBufferLength(iLen); + byteBuffer[0] = 0xFF; + byteBuffer[1] = 0x00; + byteBuffer[2] = 0x00; + byteBuffer[3] = 0x00; + byteBuffer[4] = 0x00; + + iTotalLen = FileStream.WriteData(byteBuffer, 5); + } + else + { + // Convert to a unicode byte array + byte[] bytes = Encoding.Unicode.GetBytes(sVal); + + // convert char array into byte array and write it out + iTotalLen = WriteLength(bytes.Length); + iTotalLen += FileStream.WriteData(bytes, bytes.Length); + } + return iTotalLen; // len+data + } + + /// + /// Writes a byte[] to the file + /// + /// Number 
of bytes used to store the byte[] + public int WriteBytes(byte[] bytesVal, int iLen) + { + if (bytesVal == null) + { + throw new ArgumentNullException(nameof(bytesVal), "Byte array to store must be non-null."); + } + + int iTotalLen; + if (0 == iLen) // special case of 0 length byte array "0x" + { + iLen = 5; + + AssureBufferLength(iLen); + byteBuffer[0] = 0xFF; + byteBuffer[1] = 0x00; + byteBuffer[2] = 0x00; + byteBuffer[3] = 0x00; + byteBuffer[4] = 0x00; + + iTotalLen = FileStream.WriteData(byteBuffer, iLen); + } + else + { + iTotalLen = WriteLength(iLen); + iTotalLen += FileStream.WriteData(bytesVal, iLen); + } + return iTotalLen; // len+data + } + + /// + /// Writes the length of the field using the appropriate number of bytes (ie, 1 if the + /// length is <255, 5 if the length is >=255) + /// + /// Number of bytes used to store the length + private int WriteLength(int iLen) + { + if (iLen < 0xFF) + { + // fits in one byte of memory only need to write one byte + int iTmp = iLen & 0x000000FF; + + byteBuffer[0] = Convert.ToByte(iTmp); + return FileStream.WriteData(byteBuffer, 1); + } + // The length won't fit in 1 byte, so we need to use 1 byte to signify that the length + // is a full 4 bytes. + byteBuffer[0] = 0xFF; + + // convert int32 into array of bytes + intBuffer[0] = iLen; + Buffer.BlockCopy(intBuffer, 0, byteBuffer, 1, 4); + return FileStream.WriteData(byteBuffer, 5); + } + + /// + /// Flushes the internal buffer to the file stream + /// + public void FlushBuffer() + { + FileStream.Flush(); + } + + #endregion + + private void AssureBufferLength(int newBufferLength) + { + if (newBufferLength > byteBuffer.Length) + { + byteBuffer = new byte[byteBuffer.Length]; + } + } + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + FileStream.Flush(); + FileStream.Dispose(); + } + + disposed = true; + } + + ~ServiceBufferFileStreamWriter() + { + Dispose(false); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs new file mode 100644 index 00000000..f63046b1 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/StorageDataReader.cs @@ -0,0 +1,356 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
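
The writer above frames every value the same way: a lone 0x00 byte for NULL, a one-byte length for payloads shorter than 0xFF bytes, and an 0xFF escape byte followed by a four-byte length otherwise; zero-length strings and byte arrays deliberately take the escaped form so they can never be confused with NULL. The sketch below is not the service's writer, just a self-contained illustration of that framing with hypothetical helper names, paired with the matching read side:

    using System.IO;

    internal static class LengthPrefixFramingSketch
    {
        // Write one field: 0x00 marker for null, a short length byte for < 0xFF bytes,
        // otherwise 0xFF followed by a full Int32 length. Zero-length payloads use the
        // escaped form so a bare 0x00 always means NULL.
        public static void WriteField(BinaryWriter writer, byte[] payload)
        {
            if (payload == null)
            {
                writer.Write((byte)0x00);
                return;
            }
            if (payload.Length > 0 && payload.Length < 0xFF)
            {
                writer.Write((byte)payload.Length);
            }
            else
            {
                writer.Write((byte)0xFF);
                writer.Write(payload.Length);   // little-endian Int32 length
            }
            writer.Write(payload);
        }

        // Invert the framing: returns null when the NULL marker is found.
        public static byte[] ReadField(BinaryReader reader)
        {
            byte first = reader.ReadByte();
            if (first == 0x00)
            {
                return null;
            }
            int length = first == 0xFF ? reader.ReadInt32() : first;
            return reader.ReadBytes(length);
        }
    }

Fixed-width values follow the same pattern with a known length byte, which is why the writer stores DateTime, DateTimeOffset, and TimeSpan as 0x08-prefixed Int64 tick counts.
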
+// + +using System; +using System.Data.Common; +using System.Data.SqlClient; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using System.Xml; +using Microsoft.SqlTools.EditorServices.Utility; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage +{ + /// + /// Wrapper around a DbData reader to perform some special operations more simply + /// + public class StorageDataReader + { + // This code is based on code from Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage, + // StorageDataReader + // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\StorageDataReader.cs + + #region Member Variables + + /// + /// If the DbDataReader is a SqlDataReader, it will be set here + /// + private readonly SqlDataReader sqlDataReader; + + /// + /// Whether or not the data reader supports SqlXml types + /// + private readonly bool supportSqlXml; + + #endregion + + /// + /// Constructs a new wrapper around the provided reader + /// + /// The reader to wrap around + public StorageDataReader(DbDataReader reader) + { + // Sanity check to make sure there is a data reader + Validate.IsNotNull(nameof(reader), reader); + + // Attempt to use this reader as a SqlDataReader + sqlDataReader = reader as SqlDataReader; + supportSqlXml = sqlDataReader != null; + DbDataReader = reader; + + // Read the columns into a set of wrappers + Columns = DbDataReader.GetColumnSchema().Select(column => new DbColumnWrapper(column)).ToArray(); + } + + #region Properties + + /// + /// All the columns that this reader currently contains + /// + public DbColumnWrapper[] Columns { get; private set; } + + /// + /// The that will be read from + /// + public DbDataReader DbDataReader { get; private set; } + + #endregion + + #region DbDataReader Methods + + /// + /// Pass-through to DbDataReader.ReadAsync() + /// + /// The cancellation token to use for cancelling a query + /// + public Task ReadAsync(CancellationToken cancellationToken) + { + return DbDataReader.ReadAsync(cancellationToken); + } + + /// + /// Retrieves a value + /// + /// Column ordinal + /// The value of the given column + public object GetValue(int i) + { + return sqlDataReader == null ? 
DbDataReader.GetValue(i) : sqlDataReader.GetValue(i); + } + + /// + /// Stores all values of the current row into the provided object array + /// + /// Where to store the values from this row + public void GetValues(object[] values) + { + if (sqlDataReader == null) + { + DbDataReader.GetValues(values); + } + else + { + sqlDataReader.GetValues(values); + } + } + + /// + /// Whether or not the cell of the given column at the current row is a DBNull + /// + /// Column ordinal + /// True if the cell is DBNull, false otherwise + public bool IsDBNull(int i) + { + return DbDataReader.IsDBNull(i); + } + + #endregion + + #region Public Methods + + /// + /// Retrieves bytes with a maximum number of bytes to return + /// + /// Column ordinal + /// Number of bytes to return at maximum + /// Byte array + public byte[] GetBytesWithMaxCapacity(int iCol, int maxNumBytesToReturn) + { + if (maxNumBytesToReturn <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxNumBytesToReturn), "Maximum number of bytes to return must be greater than zero."); + } + + //first, ask provider how much data it has and calculate the final # of bytes + //NOTE: -1 means that it doesn't know how much data it has + long neededLength; + long origLength = neededLength = GetBytes(iCol, 0, null, 0, 0); + if (neededLength == -1 || neededLength > maxNumBytesToReturn) + { + neededLength = maxNumBytesToReturn; + } + + //get the data up to the maxNumBytesToReturn + byte[] bytesBuffer = new byte[neededLength]; + GetBytes(iCol, 0, bytesBuffer, 0, (int)neededLength); + + //see if server sent back more data than we should return + if (origLength == -1 || origLength > neededLength) + { + //pump the rest of data from the reader and discard it right away + long dataIndex = neededLength; + const int tmpBufSize = 100000; + byte[] tmpBuf = new byte[tmpBufSize]; + while (GetBytes(iCol, dataIndex, tmpBuf, 0, tmpBufSize) == tmpBufSize) + { + dataIndex += tmpBufSize; + } + } + + return bytesBuffer; + } + + /// + /// Retrieves characters with a maximum number of charss to return + /// + /// Column ordinal + /// Number of chars to return at maximum + /// String + public string GetCharsWithMaxCapacity(int iCol, int maxCharsToReturn) + { + if (maxCharsToReturn <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxCharsToReturn), "Maximum number of chars to return must be greater than zero"); + } + + //first, ask provider how much data it has and calculate the final # of chars + //NOTE: -1 means that it doesn't know how much data it has + long neededLength; + long origLength = neededLength = GetChars(iCol, 0, null, 0, 0); + if (neededLength == -1 || neededLength > maxCharsToReturn) + { + neededLength = maxCharsToReturn; + } + Debug.Assert(neededLength < int.MaxValue); + + //get the data up to maxCharsToReturn + char[] buffer = new char[neededLength]; + if (neededLength > 0) + { + GetChars(iCol, 0, buffer, 0, (int)neededLength); + } + + //see if server sent back more data than we should return + if (origLength == -1 || origLength > neededLength) + { + //pump the rest of data from the reader and discard it right away + long dataIndex = neededLength; + const int tmpBufSize = 100000; + char[] tmpBuf = new char[tmpBufSize]; + while (GetChars(iCol, dataIndex, tmpBuf, 0, tmpBufSize) == tmpBufSize) + { + dataIndex += tmpBufSize; + } + } + string res = new string(buffer); + return res; + } + + /// + /// Retrieves xml with a maximum number of bytes to return + /// + /// Column ordinal + /// Number of chars to return at maximum + /// String + public string 
GetXmlWithMaxCapacity(int iCol, int maxCharsToReturn) + { + if (supportSqlXml) + { + SqlXml sm = GetSqlXml(iCol); + if (sm == null) + { + return null; + } + + //this code is mostly copied from SqlClient implementation of returning value for XML data type + StringWriterWithMaxCapacity sw = new StringWriterWithMaxCapacity(null, maxCharsToReturn); + XmlWriterSettings writerSettings = new XmlWriterSettings + { + CloseOutput = false, + ConformanceLevel = ConformanceLevel.Fragment + }; + // don't close the memory stream + XmlWriter ww = XmlWriter.Create(sw, writerSettings); + + XmlReader reader = sm.CreateReader(); + reader.Read(); + + while (!reader.EOF) + { + ww.WriteNode(reader, true); + } + ww.Flush(); + return sw.ToString(); + } + + object o = GetValue(iCol); + return o?.ToString(); + } + + #endregion + + #region Private Helpers + + private long GetBytes(int i, long dataIndex, byte[] buffer, int bufferIndex, int length) + { + return DbDataReader.GetBytes(i, dataIndex, buffer, bufferIndex, length); + } + + private long GetChars(int i, long dataIndex, char[] buffer, int bufferIndex, int length) + { + return DbDataReader.GetChars(i, dataIndex, buffer, bufferIndex, length); + } + + private SqlXml GetSqlXml(int i) + { + if (sqlDataReader == null) + { + // We need a Sql data reader in order to retrieve sql xml + throw new InvalidOperationException("Cannot retrieve SqlXml without a SqlDataReader"); + } + + return sqlDataReader.GetSqlXml(i); + } + + #endregion + + /// + /// Internal class for writing strings with a maximum capacity + /// + /// + /// This code is take almost verbatim from Microsoft.SqlServer.Management.UI.Grid, SSMS + /// DataStorage, StorageDataReader class. + /// + private class StringWriterWithMaxCapacity : StringWriter + { + private bool stopWriting; + + private int CurrentLength + { + get { return GetStringBuilder().Length; } + } + + public StringWriterWithMaxCapacity(IFormatProvider formatProvider, int capacity) : base(formatProvider) + { + MaximumCapacity = capacity; + } + + private int MaximumCapacity { get; set; } + + public override void Write(char value) + { + if (stopWriting) { return; } + + if (CurrentLength < MaximumCapacity) + { + base.Write(value); + } + else + { + stopWriting = true; + } + } + + public override void Write(char[] buffer, int index, int count) + { + if (stopWriting) { return; } + + int curLen = CurrentLength; + if (curLen + (count - index) > MaximumCapacity) + { + stopWriting = true; + + count = MaximumCapacity - curLen + index; + if (count < 0) + { + count = 0; + } + } + base.Write(buffer, index, count); + } + + public override void Write(string value) + { + if (stopWriting) { return; } + + int curLen = CurrentLength; + if (value.Length + curLen > MaximumCapacity) + { + stopWriting = true; + base.Write(value.Substring(0, MaximumCapacity - curLen)); + } + else + { + base.Write(value); + } + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index d5ce270d..2da2a15d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -1,7 +1,6 @@ -// +// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. 
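
GetBytesWithMaxCapacity and GetCharsWithMaxCapacity above first probe the provider for the field's total length (a null buffer returns the length, or -1 when it is unknown), clamp the read to the configured maximum, and then drain whatever the server still has queued so the reader can advance. A rough standalone sketch of the same capped-read pattern against a generic DbDataReader; the method name and cap are placeholders, and the drain loop is omitted for brevity:

    using System.Data.Common;

    internal static class CappedReadSketch
    {
        // Returns at most maxChars characters of the text column at `ordinal`
        // for the reader's current row.
        public static string ReadCappedChars(DbDataReader reader, int ordinal, int maxChars)
        {
            // Passing a null buffer asks the provider for the total length; -1 means unknown.
            long total = reader.GetChars(ordinal, 0, null, 0, 0);
            int toRead = (total == -1 || total > maxChars) ? maxChars : (int)total;

            char[] buffer = new char[toRead];
            if (toRead > 0)
            {
                reader.GetChars(ordinal, 0, buffer, 0, toRead);
            }
            return new string(buffer);
        }
    }

In the class above the drain loop simply pumps and discards the bytes or characters past the cap so the underlying reader is free to move to the next column or row.
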
-// using System; using System.Data.Common; @@ -12,6 +11,7 @@ using System.Threading.Tasks; using Microsoft.SqlServer.Management.SqlParser.Parser; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; using Microsoft.SqlTools.ServiceLayer.SqlContext; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution @@ -21,15 +21,84 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public class Query : IDisposable { - #region Constants - /// /// "Error" code produced by SQL Server when the database context (name) for a connection changes. /// private const int DatabaseContextChangeErrorNumber = 5701; + #region Member Variables + + /// + /// Cancellation token source, used for cancelling async db actions + /// + private readonly CancellationTokenSource cancellationSource; + + /// + /// For IDisposable implementation, whether or not this object has been disposed + /// + private bool disposed; + + /// + /// The connection info associated with the file editor owner URI, used to create a new + /// connection upon execution of the query + /// + private readonly ConnectionInfo editorConnection; + + /// + /// Whether or not the execute method has been called for this query + /// + private bool hasExecuteBeenCalled; + + /// + /// The factory to use for outputting the results of this query + /// + private readonly IFileStreamFactory outputFileFactory; + #endregion + /// + /// Constructor for a query + /// + /// The text of the query to execute + /// The information of the connection to use to execute the query + /// Settings for how to execute the query, from the user + /// Factory for creating output files + public Query(string queryText, ConnectionInfo connection, QueryExecutionSettings settings, IFileStreamFactory outputFactory) + { + // Sanity check for input + if (string.IsNullOrEmpty(queryText)) + { + throw new ArgumentNullException(nameof(queryText), "Query text cannot be null"); + } + if (connection == null) + { + throw new ArgumentNullException(nameof(connection), "Connection cannot be null"); + } + if (settings == null) + { + throw new ArgumentNullException(nameof(settings), "Settings cannot be null"); + } + if (outputFactory == null) + { + throw new ArgumentNullException(nameof(outputFactory), "Output file factory cannot be null"); + } + + // Initialize the internal state + QueryText = queryText; + editorConnection = connection; + cancellationSource = new CancellationTokenSource(); + outputFileFactory = outputFactory; + + // Process the query into batches + ParseResult parseResult = Parser.Parse(queryText, new ParseOptions + { + BatchSeparator = settings.BatchSeparator + }); + // NOTE: We only want to process batches that have statements (ie, ignore comments and empty lines) + Batches = parseResult.Script.Batches.Where(b => b.Statements.Count > 0) + .Select(b => new Batch(b.Sql, b.StartLocation.LineNumber, outputFileFactory)).ToArray(); + } + #region Properties /// @@ -59,19 +128,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } } - /// - /// Cancellation token source, used for cancelling async db actions - /// - private readonly CancellationTokenSource cancellationSource; - - /// - /// The connection info associated with the file editor owner URI, used to create a new - /// connection upon execution of the query - /// - private ConnectionInfo EditorConnection { get; set; } - - private bool HasExecuteBeenCalled { get; set; } - /// /// Whether or 
not the query has completed executed, regardless of success or failure /// @@ -80,10 +136,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// public bool HasExecuted { - get { return Batches.Length == 0 ? HasExecuteBeenCalled : Batches.All(b => b.HasExecuted); } + get { return Batches.Length == 0 ? hasExecuteBeenCalled : Batches.All(b => b.HasExecuted); } internal set { - HasExecuteBeenCalled = value; + hasExecuteBeenCalled = value; foreach (var batch in Batches) { batch.HasExecuted = value; @@ -98,41 +154,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #endregion + #region Public Methods + /// - /// Constructor for a query + /// Cancels the query by issuing the cancellation token /// - /// The text of the query to execute - /// The information of the connection to use to execute the query - /// Settings for how to execute the query, from the user - public Query(string queryText, ConnectionInfo connection, QueryExecutionSettings settings) + public void Cancel() { - // Sanity check for input - if (string.IsNullOrEmpty(queryText)) + // Make sure that the query hasn't completed execution + if (HasExecuted) { - throw new ArgumentNullException(nameof(queryText), "Query text cannot be null"); - } - if (connection == null) - { - throw new ArgumentNullException(nameof(connection), "Connection cannot be null"); - } - if (settings == null) - { - throw new ArgumentNullException(nameof(settings), "Settings cannot be null"); + throw new InvalidOperationException("The query has already completed, it cannot be cancelled."); } - // Initialize the internal state - QueryText = queryText; - EditorConnection = connection; - cancellationSource = new CancellationTokenSource(); - - // Process the query into batches - ParseResult parseResult = Parser.Parse(queryText, new ParseOptions - { - BatchSeparator = settings.BatchSeparator - }); - // NOTE: We only want to process batches that have statements (ie, ignore comments and empty lines) - Batches = parseResult.Script.Batches.Where(b => b.Statements.Count > 0) - .Select(b => new Batch(b.Sql, b.StartLocation.LineNumber)).ToArray(); + // Issue the cancellation token for the query + cancellationSource.Cancel(); } /// @@ -141,7 +177,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution public async Task Execute() { // Mark that we've internally executed - HasExecuteBeenCalled = true; + hasExecuteBeenCalled = true; // Don't actually execute if there aren't any batches to execute if (Batches.Length == 0) @@ -150,8 +186,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution } // Open up a connection for querying the database - string connectionString = ConnectionService.BuildConnectionString(EditorConnection.ConnectionDetails); - using (DbConnection conn = EditorConnection.Factory.CreateSqlConnection(connectionString)) + string connectionString = ConnectionService.BuildConnectionString(editorConnection.ConnectionDetails); + // TODO: Don't create a new connection every time, see TFS #834978 + using (DbConnection conn = editorConnection.Factory.CreateSqlConnection(connectionString)) { await conn.OpenAsync(); @@ -167,6 +204,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { await b.Execute(conn, cancellationSource.Token); } + + // TODO: Close connection after eliminating using statement for above TODO } } @@ -176,13 +215,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution private void OnInfoMessage(object sender, SqlInfoMessageEventArgs args) { SqlConnection conn = sender as SqlConnection; + if (conn == 
null) + { + throw new InvalidOperationException("Sender for OnInfoMessage event must be a SqlConnection"); + } foreach(SqlError error in args.Errors) { // Did the database context change (error code 5701)? if (error.Number == DatabaseContextChangeErrorNumber) { - ConnectionService.Instance.ChangeConnectionDatabaseContext(EditorConnection.OwnerUri, conn.Database); + ConnectionService.Instance.ChangeConnectionDatabaseContext(editorConnection.OwnerUri, conn.Database); } } } @@ -195,7 +238,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution /// The starting row of the results /// How many rows to retrieve /// A subset of results - public ResultSetSubset GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) + public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount) { // Sanity check that the results are available if (!HasExecuted) @@ -213,25 +256,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution return Batches[batchIndex].GetSubset(resultSetIndex, startRow, rowCount); } - /// - /// Cancels the query by issuing the cancellation token - /// - public void Cancel() - { - // Make sure that the query hasn't completed execution - if (HasExecuted) - { - throw new InvalidOperationException("The query has already completed, it cannot be cancelled."); - } - - // Issue the cancellation token for the query - cancellationSource.Cancel(); - } + #endregion #region IDisposable Implementation - private bool disposed; - public void Dispose() { Dispose(true); @@ -248,16 +276,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution if (disposing) { cancellationSource.Dispose(); + foreach (Batch b in Batches) + { + b.Dispose(); + } } disposed = true; } - ~Query() - { - Dispose(false); - } - #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs index f23925ab..97d89fc9 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -10,6 +10,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Workspace; @@ -24,6 +25,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution private static readonly Lazy instance = new Lazy(() => new QueryExecutionService()); + /// + /// Singleton instance of the query execution service + /// public static QueryExecutionService Instance { get { return instance.Value; } @@ -43,6 +47,22 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution #region Properties + /// + /// File factory to be used to create a buffer file for results. + /// + /// + /// Made internal here to allow for overriding in unit testing + /// + internal IFileStreamFactory BufferFileStreamFactory; + + /// + /// File factory to be used to create a buffer file for results + /// + private IFileStreamFactory BufferFileFactory + { + get { return BufferFileStreamFactory ?? 
(BufferFileStreamFactory = new ServiceBufferFileStreamFactory()); } + } + /// /// The collection of active queries /// @@ -134,7 +154,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution var result = new QueryExecuteSubsetResult { Message = null, - ResultSubset = query.GetSubset(subsetParams.BatchIndex, + ResultSubset = await query.GetSubset(subsetParams.BatchIndex, subsetParams.ResultSetIndex, subsetParams.RowsStartIndex, subsetParams.RowsCount) }; await requestContext.SendResult(result); @@ -266,7 +286,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution QueryExecutionSettings settings = WorkspaceService.Instance.CurrentSettings.QueryExecutionSettings; // If we can't add the query now, it's assumed the query is in progress - Query newQuery = new Query(executeParams.QueryText, connectionInfo, settings); + Query newQuery = new Query(executeParams.QueryText, connectionInfo, settings, BufferFileFactory); if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) { await requestContext.SendResult(new QueryExecuteResult diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs index 058dc54c..84e18c99 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -1,4 +1,4 @@ -// +// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. // @@ -7,46 +7,130 @@ using System; using System.Collections.Generic; using System.Data.Common; using System.Linq; +using System.Threading; +using System.Threading.Tasks; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { - public class ResultSet + public class ResultSet : IDisposable { - public DbColumn[] Columns { get; set; } + #region Constants - public List Rows { get; private set; } + private const int DefaultMaxCharsToStore = 65535; // 64 KB - QE default - public ResultSet() - { - Rows = new List(); - } + // xml is a special case so number of chars to store is usually greater than for other long types + private const int DefaultMaxXmlCharsToStore = 2097152; // 2 MB - QE default + + #endregion + + #region Member Variables /// - /// Add a row of data to the result set using a that has already - /// read in a row. + /// For IDisposable pattern, whether or not object has been disposed /// - /// A that has already had a read performed - public void AddRow(DbDataReader reader) + private bool disposed; + + /// + /// The factory to use to get reading/writing handlers + /// + private readonly IFileStreamFactory fileStreamFactory; + + /// + /// File stream reader that will be reused to make rapid-fire retrieval of result subsets + /// quick and low perf impact. 
+ /// + private IFileStreamReader fileStreamReader; + + /// + /// Whether or not the result set has been read in from the database + /// + private bool hasBeenRead; + + /// + /// The name of the temporary file we're using to output these results in + /// + private readonly string outputFileName; + + #endregion + + /// + /// Creates a new result set and initializes its state + /// + /// The reader from executing a query + /// Factory for creating a reader/writer + public ResultSet(DbDataReader reader, IFileStreamFactory factory) { - List row = new List(); - for (int i = 0; i < reader.FieldCount; ++i) + // Sanity check to make sure we got a reader + if (reader == null) { - row.Add(reader.GetValue(i)); + throw new ArgumentNullException(nameof(reader), "Reader cannot be null"); } - Rows.Add(row.ToArray()); + DataReader = new StorageDataReader(reader); + + // Initialize the storage + outputFileName = factory.CreateFile(); + FileOffsets = new LongList(); + + // Store the factory + fileStreamFactory = factory; + hasBeenRead = false; } + #region Properties + + /// + /// The columns for this result set + /// + public DbColumnWrapper[] Columns { get; private set; } + + /// + /// The reader to use for this resultset + /// + private StorageDataReader DataReader { get; set; } + + /// + /// A list of offsets into the buffer file that correspond to where rows start + /// + private LongList FileOffsets { get; set; } + + /// + /// Maximum number of characters to store for a field + /// + public int MaxCharsToStore { get { return DefaultMaxCharsToStore; } } + + /// + /// Maximum number of characters to store for an XML field + /// + public int MaxXmlCharsToStore { get { return DefaultMaxXmlCharsToStore; } } + + /// + /// The number of rows for this result set + /// + public long RowCount { get; private set; } + + #endregion + + #region Public Methods + /// /// Generates a subset of the rows from the result set /// /// The starting row of the results /// How many rows to retrieve /// A subset of results - public ResultSetSubset GetSubset(int startRow, int rowCount) + public Task GetSubset(int startRow, int rowCount) { + // Sanity check to make sure that the results have been read beforehand + if (!hasBeenRead || fileStreamReader == null) + { + throw new InvalidOperationException("Cannot read subset unless the results have been read from the server"); + } + // Sanity check to make sure that the row and the row count are within bounds - if (startRow < 0 || startRow >= Rows.Count) + if (startRow < 0 || startRow >= RowCount) { throw new ArgumentOutOfRangeException(nameof(startRow), "Start row cannot be less than 0 " + "or greater than the number of rows in the resultset"); @@ -56,13 +140,79 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution throw new ArgumentOutOfRangeException(nameof(rowCount), "Row count must be a positive integer"); } - // Retrieve the subset of the results as per the request - object[][] rows = Rows.Skip(startRow).Take(rowCount).ToArray(); - return new ResultSetSubset + return Task.Factory.StartNew(() => { - Rows = rows, - RowCount = rows.Length - }; + // Figure out which rows we need to read back + IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount); + + // Iterate over the rows we need and process them into output + object[][] rows = rowOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)).ToArray(); + + // Retrieve the subset of the results as per the request + return new ResultSetSubset + { + Rows = rows, + RowCount = rows.Length + }; + }); } + 
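
With results spilled to the buffer file, GetSubset above is purely an offset lookup: FileOffsets remembers where each row starts, so a subset request seeks to those positions and deserializes only the requested window instead of holding every row in memory. A compact sketch of that offset-index pattern over an ordinary stream (the types and names here are illustrative, not the service's storage classes):

    using System.Collections.Generic;
    using System.IO;
    using System.Text;

    internal static class OffsetIndexSketch
    {
        // Appends rows to the stream and returns the offset where each row starts.
        public static List<long> WriteRows(Stream stream, IEnumerable<string> rows)
        {
            var offsets = new List<long>();
            using (var writer = new BinaryWriter(stream, Encoding.Unicode, leaveOpen: true))
            {
                foreach (string row in rows)
                {
                    offsets.Add(stream.Position);   // remember where this row begins
                    writer.Write(row);              // BinaryWriter length-prefixes the string
                }
            }
            return offsets;
        }

        // Reads back rows [startRow, startRow + rowCount) using the recorded offsets.
        public static string[] ReadSubset(Stream stream, IReadOnlyList<long> offsets,
                                          int startRow, int rowCount)
        {
            var result = new List<string>();
            using (var reader = new BinaryReader(stream, Encoding.Unicode, leaveOpen: true))
            {
                for (int i = startRow; i < startRow + rowCount && i < offsets.Count; i++)
                {
                    stream.Seek(offsets[i], SeekOrigin.Begin);  // jump straight to the row
                    result.Add(reader.ReadString());
                }
            }
            return result.ToArray();
        }
    }

Running GetSubset on a background task, as the implementation above does, keeps the file I/O off the request thread while the offset list keeps each lookup O(1).
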
+ /// + /// Reads from the reader until there are no more results to read + /// + /// Cancellation token for cancelling the query + public async Task ReadResultToEnd(CancellationToken cancellationToken) + { + // Open a writer for the file + using (IFileStreamWriter fileWriter = fileStreamFactory.GetWriter(outputFileName, MaxCharsToStore, MaxXmlCharsToStore)) + { + // If we can initialize the columns using the column schema, use that + if (!DataReader.DbDataReader.CanGetColumnSchema()) + { + throw new InvalidOperationException("Could not retrieve column schema for result set."); + } + Columns = DataReader.Columns; + long currentFileOffset = 0; + + while (await DataReader.ReadAsync(cancellationToken)) + { + RowCount++; + FileOffsets.Add(currentFileOffset); + currentFileOffset += fileWriter.WriteRow(DataReader); + } + } + + // Mark that result has been read + hasBeenRead = true; + fileStreamReader = fileStreamFactory.GetReader(outputFileName); + } + + #endregion + + #region IDisposable Implementation + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + fileStreamReader?.Dispose(); + fileStreamFactory.DisposeFile(outputFileName); + } + + disposed = true; + } + + #endregion } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs new file mode 100644 index 00000000..afacc98f --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs @@ -0,0 +1,259 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.Utility +{ + /// + /// Collection class that permits storage of over int.MaxValue items. This is performed + /// by using a 2D list of lists. The internal lists are only initialized as necessary. This + /// collection implements IEnumerable to make it easier to run LINQ queries against it. 
+ /// + /// + /// This class is based on code from $\Data Tools\SSMS_Main\sql\ssms\core\DataStorage\ArrayList64.cs + /// with additions to bring it up to .NET 4.5 standards + /// + /// Type of the values to store + public class LongList : IEnumerable + { + #region Member Variables + + private List> expandedList; + private readonly List shortList; + + #endregion + + /// + /// Creates a new long list + /// + public LongList() + { + shortList = new List(); + Count = 0; + } + + #region Properties + + /// + /// The total number of elements in the array + /// + public long Count { get; private set; } + + public T this[long index] + { + get { return GetItem(index); } + } + + #endregion + + #region Public Methods + + /// + /// Adds the specified value to the end of the list + /// + /// Value to add to the list + /// Index of the item that was just added + public long Add(T val) + { + if (Count <= int.MaxValue) + { + shortList.Add(val); + } + else // need to split values into several arrays + { + if (expandedList == null) + { + // very inefficient so delay as much as possible + // immediately add 0th array + expandedList = new List> {shortList}; + } + + int arrayIndex = (int)(Count/int.MaxValue); // 0 based + + List arr; + if (expandedList.Count <= arrayIndex) // need to make a new array + { + arr = new List(); + expandedList.Add(arr); + } + else // use existing array + { + arr = expandedList[arrayIndex]; + } + arr.Add(val); + } + return (++Count); + } + + /// + /// Returns the item at the specified index + /// + /// Index of the item to return + /// The item at the index specified + public T GetItem(long index) + { + T val = default(T); + + if (Count <= int.MaxValue) + { + int i32Index = Convert.ToInt32(index); + val = shortList[i32Index]; + } + else + { + int iArray32Index = (int) (Count/int.MaxValue); + if (expandedList.Count > iArray32Index) + { + List arr = expandedList[iArray32Index]; + + int i32Index = (int) (Count%int.MaxValue); + if (arr.Count > i32Index) + { + val = arr[i32Index]; + } + } + } + return val; + } + + /// + /// Removes an item at the specified location and shifts all the items after the provided + /// index up by one. 
+ /// + /// The index to remove from the list + public void RemoveAt(long index) + { + if (Count <= int.MaxValue) + { + int iArray32MemberIndex = Convert.ToInt32(index); // 0 based + shortList.RemoveAt(iArray32MemberIndex); + } + else // handle the case of multiple arrays + { + // find out which array it is in + int arrayIndex = (int) (index/int.MaxValue); + List arr = expandedList[arrayIndex]; + + // find out index into this array + int iArray32MemberIndex = (int) (index%int.MaxValue); + arr.RemoveAt(iArray32MemberIndex); + + // now shift members of the array back one + int iArray32TotalIndex = (int) (Count/Int32.MaxValue); + for (int i = arrayIndex + 1; i < iArray32TotalIndex; i++) + { + List arr1 = expandedList[i - 1]; + List arr2 = expandedList[i]; + + arr1.Add(arr2[int.MaxValue - 1]); + arr2.RemoveAt(0); + } + } + --Count; + } + + #endregion + + #region IEnumerable Implementation + + /// + /// Returns a generic enumerator for enumeration of this LongList + /// + /// Enumerator for LongList + public IEnumerator GetEnumerator() + { + return new LongListEnumerator(this); + } + + /// + /// Returns an enumerator for enumeration of this LongList + /// + /// + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + #endregion + + public class LongListEnumerator : IEnumerator + { + #region Member Variables + + /// + /// The index into the list of the item that is the current item + /// + private long index; + + /// + /// The current list that we're iterating over. + /// + private readonly LongList localList; + + #endregion + + /// + /// Constructs a new enumerator for a given LongList + /// + /// The list to enumerate + public LongListEnumerator(LongList list) + { + localList = list; + index = 0; + Current = default(TEt); + } + + #region IEnumerator Implementation + + /// + /// Returns the current item in the enumeration + /// + public TEt Current { get; private set; } + + object IEnumerator.Current + { + get { return Current; } + } + + /// + /// Moves to the next item in the list we're iterating over + /// + /// Whether or not the move was successful + public bool MoveNext() + { + if (index < localList.Count) + { + Current = localList[index]; + index++; + return true; + } + Current = default(TEt); + return false; + } + + /// + /// Resets the enumeration + /// + public void Reset() + { + index = 0; + Current = default(TEt); + } + + /// + /// Disposal method. Does nothing. + /// + public void Dispose() + { + } + + #endregion + } + } +} + diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs index 2437dc30..c7a00c09 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -3,9 +3,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. 
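
FileOffsets is a LongList of Int64 offsets precisely so the row index itself is not capped at int.MaxValue entries; until the list grows past that point it simply forwards to an ordinary List. A small usage sketch, assuming the class compiles as shown above (the Demo wrapper is hypothetical):

    using System.Linq;
    using Microsoft.SqlTools.ServiceLayer.Utility;

    internal static class LongListUsageSketch
    {
        public static void Demo()
        {
            LongList<long> offsets = new LongList<long>();

            // Add returns the running count after the insert (1, 2, 3, ...).
            offsets.Add(0L);
            offsets.Add(128L);
            offsets.Add(256L);

            long second = offsets[1];                           // 128, via the long indexer

            // IEnumerable<T> support means LINQ paging works directly,
            // which is how ResultSet slices FileOffsets for a subset.
            long[] window = offsets.Skip(1).Take(2).ToArray();  // { 128, 256 }

            offsets.RemoveAt(0);                                // shifts later items up by one
        }
    }
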
// +using System; using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.IO; using System.Data.SqlClient; using System.Threading; using Microsoft.SqlTools.ServiceLayer.Connection; @@ -16,6 +18,7 @@ using Microsoft.SqlServer.Management.SqlParser.Binder; using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; @@ -71,7 +74,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public static Batch GetBasicExecutedBatch() { - Batch batch = new Batch(StandardQuery, 1); + Batch batch = new Batch(StandardQuery, 1, GetFileStreamFactory()); batch.Execute(CreateTestConnection(new[] {StandardTestData}, false), CancellationToken.None).Wait(); return batch; } @@ -79,11 +82,78 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution public static Query GetBasicExecutedQuery() { ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false); - Query query = new Query(StandardQuery, ci, new QueryExecutionSettings()); + Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory()); query.Execute().Wait(); return query; } + #region FileStreamWriteMocking + + public static IFileStreamFactory GetFileStreamFactory() + { + Mock mock = new Mock(); + mock.Setup(fsf => fsf.GetReader(It.IsAny())) + .Returns(new ServiceBufferFileStreamReader(new InMemoryWrapper(), It.IsAny())); + mock.Setup(fsf => fsf.GetWriter(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(new ServiceBufferFileStreamWriter(new InMemoryWrapper(), It.IsAny(), 1024, + 1024)); + + return mock.Object; + } + + public class InMemoryWrapper : IFileStreamWrapper + { + private readonly byte[] storage = new byte[8192]; + private readonly MemoryStream memoryStream; + private bool readingOnly; + + public InMemoryWrapper() + { + memoryStream = new MemoryStream(storage); + } + + public void Dispose() + { + // We'll dispose this via a special method + } + + public void Init(string fileName, int bufferSize, FileAccess fAccess) + { + readingOnly = fAccess == FileAccess.Read; + } + + public int ReadData(byte[] buffer, int bytes) + { + return ReadData(buffer, bytes, memoryStream.Position); + } + + public int ReadData(byte[] buffer, int bytes, long fileOffset) + { + memoryStream.Seek(fileOffset, SeekOrigin.Begin); + return memoryStream.Read(buffer, 0, bytes); + } + + public int WriteData(byte[] buffer, int bytes) + { + if (readingOnly) { throw new InvalidOperationException(); } + memoryStream.Write(buffer, 0, bytes); + memoryStream.Flush(); + return bytes; + } + + public void Flush() + { + if (readingOnly) { throw new InvalidOperationException(); } + } + + public void Close() + { + memoryStream.Dispose(); + } + } + + #endregion + #region DbConnection Mocking public static DbCommand CreateTestCommand(Dictionary[][] data, bool throwOnRead) @@ -151,12 +221,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution out ConnectionInfo connInfo ) { - textDocument = new TextDocumentPosition(); - textDocument.TextDocument = new TextDocumentIdentifier(); - textDocument.TextDocument.Uri = Common.OwnerUri; - textDocument.Position = new Position(); - textDocument.Position.Line = 0; - textDocument.Position.Character = 0; + 
textDocument = new TextDocumentPosition + { + TextDocument = new TextDocumentIdentifier {Uri = OwnerUri}, + Position = new Position + { + Line = 0, + Character = 0 + } + }; connInfo = Common.CreateTestConnectionInfo(null, false); @@ -166,15 +239,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution var binder = BinderProvider.CreateBinder(metadataProvider); LanguageService.Instance.ScriptParseInfoMap.Add(textDocument.TextDocument.Uri, - new ScriptParseInfo() + new ScriptParseInfo { Binder = binder, MetadataProvider = metadataProvider, MetadataDisplayInfoProvider = displayInfoProvider }); - scriptFile = new ScriptFile(); - scriptFile.ClientFilePath = textDocument.TextDocument.Uri; + scriptFile = new ScriptFile {ClientFilePath = textDocument.TextDocument.Uri}; + } public static ServerConnection GetServerConnection(ConnectionInfo connection) @@ -206,7 +279,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution OwnerUri = OwnerUri }); } - return new QueryExecutionService(connectionService); + return new QueryExecutionService(connectionService) {BufferFileStreamFactory = GetFileStreamFactory()}; } #endregion diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs new file mode 100644 index 00000000..f1a4cda0 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs @@ -0,0 +1,221 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage +{ + public class FileStreamWrapperTests + { + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + public void InitInvalidFilenameParameter(string fileName) + { + // If: + // ... I have a file stream wrapper that is initialized with invalid fileName + // Then: + // ... It should throw an argument null exception + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + Assert.Throws(() => fsw.Init(fileName, 8192, FileAccess.Read)); + } + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + public void InitInvalidBufferLength(int bufferLength) + { + // If: + // ... I have a file stream wrapper that is initialized with an invalid buffer length + // Then: + // ... I should throw an argument out of range exception + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + Assert.Throws(() => fsw.Init("validFileName", bufferLength, FileAccess.Read)); + } + } + + [Fact] + public void InitInvalidFileAccessMode() + { + // If: + // ... I attempt to open a file stream wrapper that is initialized with an invalid file + // access mode + // Then: + // ... I should get an invalid argument exception + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + Assert.Throws(() => fsw.Init("validFileName", 8192, FileAccess.Write)); + } + } + + [Fact] + public void InitSuccessful() + { + string fileName = Path.GetTempFileName(); + + try + { + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + // If: + // ... I have a file stream wrapper that is initialized with valid parameters + fsw.Init(fileName, 8192, FileAccess.ReadWrite); + + // Then: + // ... 
The file should exist + FileInfo fileInfo = new FileInfo(fileName); + Assert.True(fileInfo.Exists); + + // ... The file should be marked as hidden + Assert.True((fileInfo.Attributes & FileAttributes.Hidden) != 0); + } + } + finally + { + // Cleanup: + // ... Delete the file that was created + try { File.Delete(fileName); } catch { /* Don't care */ } + } + } + + [Fact] + public void PerformOpWithoutInit() + { + byte[] buf = new byte[10]; + + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + // If: + // ... I have a file stream wrapper that hasn't been initialized + // Then: + // ... Attempting to perform any operation will result in an exception + Assert.Throws(() => fsw.ReadData(buf, 1)); + Assert.Throws(() => fsw.ReadData(buf, 1, 0)); + Assert.Throws(() => fsw.WriteData(buf, 1)); + Assert.Throws(() => fsw.Flush()); + } + } + + [Fact] + public void PerformWriteOpOnReadOnlyWrapper() + { + byte[] buf = new byte[10]; + + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + // If: + // ... I have a readonly file stream wrapper + // Then: + // ... Attempting to perform any write operation should result in an exception + Assert.Throws(() => fsw.WriteData(buf, 1)); + Assert.Throws(() => fsw.Flush()); + } + } + + [Theory] + [InlineData(1024, 20, 10)] // Standard scenario + [InlineData(1024, 100, 100)] // Requested more bytes than there are + [InlineData(5, 20, 10)] // Internal buffer too small, force a move-to operation + public void ReadData(int internalBufferLength, int outBufferLength, int requestedBytes) + { + // Setup: + // ... I have a file that has a handful of bytes in it + string fileName = Path.GetTempFileName(); + const string stringToWrite = "hello"; + CreateTestFile(fileName, stringToWrite); + byte[] targetBytes = Encoding.Unicode.GetBytes(stringToWrite); + + try + { + // If: + // ... I have a file stream wrapper that has been initialized to an existing file + // ... And I read some bytes from it + int bytesRead; + byte[] buf = new byte[outBufferLength]; + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + fsw.Init(fileName, internalBufferLength, FileAccess.Read); + bytesRead = fsw.ReadData(buf, targetBytes.Length); + } + + // Then: + // ... I should get those bytes back + Assert.Equal(targetBytes.Length, bytesRead); + Assert.True(targetBytes.Take(targetBytes.Length).SequenceEqual(buf.Take(targetBytes.Length))); + + } + finally + { + // Cleanup: + // ... Delete the test file + CleanupTestFile(fileName); + } + } + + [Theory] + [InlineData(1024)] // Standard scenario + [InlineData(10)] // Internal buffer too small, forces a flush + public void WriteData(int internalBufferLength) + { + string fileName = Path.GetTempFileName(); + byte[] bytesToWrite = Encoding.Unicode.GetBytes("hello"); + + try + { + // If: + // ... I have a file stream that has been initialized + // ... And I write some bytes to it + using (FileStreamWrapper fsw = new FileStreamWrapper()) + { + fsw.Init(fileName, internalBufferLength, FileAccess.ReadWrite); + int bytesWritten = fsw.WriteData(bytesToWrite, bytesToWrite.Length); + + Assert.Equal(bytesToWrite.Length, bytesWritten); + } + + // Then: + // ... 
The file I wrote to should contain only the bytes I wrote out + using (FileStream fs = File.OpenRead(fileName)) + { + byte[] readBackBytes = new byte[1024]; + int bytesRead = fs.Read(readBackBytes, 0, readBackBytes.Length); + + Assert.Equal(bytesToWrite.Length, bytesRead); // If bytes read is not equal, then more or less of the original string was written to the file + Assert.True(bytesToWrite.SequenceEqual(readBackBytes.Take(bytesRead))); + } + } + finally + { + // Cleanup: + // ... Delete the test file + CleanupTestFile(fileName); + } + } + + private static void CreateTestFile(string fileName, string value) + { + using (FileStream fs = new FileStream(fileName, FileMode.OpenOrCreate, FileAccess.ReadWrite)) + { + byte[] bytesToWrite = Encoding.Unicode.GetBytes(value); + fs.Write(bytesToWrite, 0, bytesToWrite.Length); + fs.Flush(); + } + } + + private static void CleanupTestFile(string fileName) + { + try { File.Delete(fileName); } catch { /* Don't Care */} + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs new file mode 100644 index 00000000..b10a7f92 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs @@ -0,0 +1,295 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.SqlTypes; +using System.Text; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage +{ + public class ReaderWriterPairTest + { + private static void VerifyReadWrite(int valueLength, T value, Func writeFunc, Func> readFunc) + { + // Setup: Create a mock file stream wrapper + Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper(); + try + { + // If: + // ... I write a type T to the writer + using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10)) + { + int writtenBytes = writeFunc(writer, value); + Assert.Equal(valueLength, writtenBytes); + } + + // ... 
And read the type T back + FileStreamReadResult outValue; + using (ServiceBufferFileStreamReader reader = new ServiceBufferFileStreamReader(mockWrapper, "abc")) + { + outValue = readFunc(reader); + } + + // Then: + Assert.Equal(value, outValue.Value); + Assert.Equal(valueLength, outValue.TotalLength); + Assert.False(outValue.IsNull); + } + finally + { + // Cleanup: Close the wrapper + mockWrapper.Close(); + } + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + [InlineData(-10)] + [InlineData(short.MaxValue)] // Two byte number + [InlineData(short.MinValue)] // Negative two byte number + public void Int16(short value) + { + VerifyReadWrite(sizeof(short) + 1, value, (writer, val) => writer.WriteInt16(val), reader => reader.ReadInt16(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + [InlineData(-10)] + [InlineData(short.MaxValue)] // Two byte number + [InlineData(short.MinValue)] // Negative two byte number + [InlineData(int.MaxValue)] // Four byte number + [InlineData(int.MinValue)] // Negative four byte number + public void Int32(int value) + { + VerifyReadWrite(sizeof(int) + 1, value, (writer, val) => writer.WriteInt32(val), reader => reader.ReadInt32(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + [InlineData(-10)] + [InlineData(short.MaxValue)] // Two byte number + [InlineData(short.MinValue)] // Negative two byte number + [InlineData(int.MaxValue)] // Four byte number + [InlineData(int.MinValue)] // Negative four byte number + [InlineData(long.MaxValue)] // Eight byte number + [InlineData(long.MinValue)] // Negative eight byte number + public void Int64(long value) + { + VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteInt64(val), reader => reader.ReadInt64(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10)] + public void Byte(byte value) + { + VerifyReadWrite(sizeof(byte) + 1, value, (writer, val) => writer.WriteByte(val), reader => reader.ReadByte(0)); + } + + [Theory] + [InlineData('a')] + [InlineData('1')] + [InlineData((char)0x9152)] // Test something in the UTF-16 space + public void Char(char value) + { + VerifyReadWrite(sizeof(char) + 1, value, (writer, val) => writer.WriteChar(val), reader => reader.ReadChar(0)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Boolean(bool value) + { + VerifyReadWrite(sizeof(bool) + 1, value, (writer, val) => writer.WriteBoolean(val), reader => reader.ReadBoolean(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10.1)] + [InlineData(-10.1)] + [InlineData(float.MinValue)] + [InlineData(float.MaxValue)] + [InlineData(float.PositiveInfinity)] + [InlineData(float.NegativeInfinity)] + public void Single(float value) + { + VerifyReadWrite(sizeof(float) + 1, value, (writer, val) => writer.WriteSingle(val), reader => reader.ReadSingle(0)); + } + + [Theory] + [InlineData(0)] + [InlineData(10.1)] + [InlineData(-10.1)] + [InlineData(float.MinValue)] + [InlineData(float.MaxValue)] + [InlineData(float.PositiveInfinity)] + [InlineData(float.NegativeInfinity)] + [InlineData(double.PositiveInfinity)] + [InlineData(double.NegativeInfinity)] + [InlineData(double.MinValue)] + [InlineData(double.MaxValue)] + public void Double(double value) + { + VerifyReadWrite(sizeof(double) + 1, value, (writer, val) => writer.WriteDouble(val), reader => reader.ReadDouble(0)); + } + + [Fact] + public void SqlDecimalTest() + { + // Setup: Create some test values + // NOTE: We are doing these here instead of InlineData because SqlDecimal values can't be written as constant expressions + 
+
+        [Fact]
+        public void SqlDecimalTest()
+        {
+            // Setup: Create some test values
+            // NOTE: We are doing these here instead of InlineData because SqlDecimal values can't be written as constant expressions
+            SqlDecimal[] testValues =
+            {
+                SqlDecimal.MaxValue, SqlDecimal.MinValue, new SqlDecimal(0x01, 0x01, true, 0, 0, 0, 0)
+            };
+            foreach (SqlDecimal value in testValues)
+            {
+                int valueLength = 4 + value.BinData.Length;
+                VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteSqlDecimal(val), reader => reader.ReadSqlDecimal(0));
+            }
+        }
+
+        [Fact]
+        public void Decimal()
+        {
+            // Setup: Create some test values
+            // NOTE: We are doing these here instead of InlineData because Decimal values can't be written as constant expressions
+            decimal[] testValues =
+            {
+                decimal.Zero, decimal.One, decimal.MinusOne, decimal.MinValue, decimal.MaxValue
+            };
+
+            foreach (decimal value in testValues)
+            {
+                int valueLength = decimal.GetBits(value).Length*4 + 1;
+                VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteDecimal(val), reader => reader.ReadDecimal(0));
+            }
+        }
+
+        [Fact]
+        public void DateTimeTest()
+        {
+            // Setup: Create some test values
+            // NOTE: We are doing these here instead of InlineData because DateTime values can't be written as constant expressions
+            DateTime[] testValues =
+            {
+                DateTime.Now, DateTime.UtcNow, DateTime.MinValue, DateTime.MaxValue
+            };
+            foreach (DateTime value in testValues)
+            {
+                VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteDateTime(val), reader => reader.ReadDateTime(0));
+            }
+        }
+
+        [Fact]
+        public void DateTimeOffsetTest()
+        {
+            // Setup: Create some test values
+            // NOTE: We are doing these here instead of InlineData because DateTimeOffset values can't be written as constant expressions
+            DateTimeOffset[] testValues =
+            {
+                DateTimeOffset.Now, DateTimeOffset.UtcNow, DateTimeOffset.MinValue, DateTimeOffset.MaxValue
+            };
+            foreach (DateTimeOffset value in testValues)
+            {
+                VerifyReadWrite((sizeof(long) + 1)*2, value, (writer, val) => writer.WriteDateTimeOffset(val), reader => reader.ReadDateTimeOffset(0));
+            }
+        }
+
+        [Fact]
+        public void TimeSpanTest()
+        {
+            // Setup: Create some test values
+            // NOTE: We are doing these here instead of InlineData because TimeSpan values can't be written as constant expressions
+            TimeSpan[] testValues =
+            {
+                TimeSpan.Zero, TimeSpan.MinValue, TimeSpan.MaxValue, TimeSpan.FromMinutes(60)
+            };
+            foreach (TimeSpan value in testValues)
+            {
+                VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteTimeSpan(val), reader => reader.ReadTimeSpan(0));
+            }
+        }
+
+        [Fact]
+        public void StringNullTest()
+        {
+            // Setup: Create a mock file stream wrapper
+            Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper();
+
+            // If:
+            // ... I write null as a string to the writer
+            using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10))
+            {
+                // Then:
+                // ... I should get an argument null exception
+                Assert.Throws<ArgumentNullException>(() => writer.WriteString(null));
+            }
+        }
+
+        [Theory]
+        [InlineData(0, null)]                                  // Test of empty string
+        [InlineData(1, new[] { 'j' })]
+        [InlineData(1, new[] { (char)0x9152 })]
+        [InlineData(100, new[] { 'j', (char)0x9152 })]         // Test alternating utf-16/ascii characters
+        [InlineData(512, new[] { 'j', (char)0x9152 })]         // Test that requires a 4 byte length
+        public void StringTest(int length, char[] values)
+        {
+            // Setup:
+            // ... Generate the test value
+            StringBuilder sb = new StringBuilder();
+            for (int i = 0; i < length; i++)
+            {
+                sb.Append(values[i%values.Length]);
+            }
+            string value = sb.ToString();
+            int lengthLength = length == 0 || length > 255 ? 5 : 1;
+            VerifyReadWrite(sizeof(char)*length + lengthLength, value, (writer, val) => writer.WriteString(value), reader => reader.ReadString(0));
+        }
+
+        [Fact]
+        public void BytesNullTest()
+        {
+            // Setup: Create a mock file stream wrapper
+            Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper();
+
+            // If:
+            // ... I write null as a byte array to the writer
+            using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10))
+            {
+                // Then:
+                // ... I should get an argument null exception
+                Assert.Throws<ArgumentNullException>(() => writer.WriteBytes(null, 0));
+            }
+        }
+
+        [Theory]
+        [InlineData(0, new byte[] { 0x00 })]                   // Test of empty byte[]
+        [InlineData(1, new byte[] { 0x00 })]
+        [InlineData(1, new byte[] { 0xFF })]
+        [InlineData(100, new byte[] { 0x10, 0xFF, 0x00 })]
+        [InlineData(512, new byte[] { 0x10, 0xFF, 0x00 })]     // Test that requires a 4 byte length
+        public void Bytes(int length, byte[] values)
+        {
+            // Setup:
+            // ... Generate the test value
+            List<byte> sb = new List<byte>();
+            for (int i = 0; i < length; i++)
+            {
+                sb.Add(values[i % values.Length]);
+            }
+            byte[] value = sb.ToArray();
+            int lengthLength = length == 0 || length > 255 ? 5 : 1;
+            int valueLength = sizeof(byte)*length + lengthLength;
+            VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteBytes(value, length), reader => reader.ReadBytes(0));
+        }
+    }
+}
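The lengthLength expression in the string and byte[] theories implies a variable-size length prefix: one byte when the value has between 1 and 255 elements, five bytes for empty values or values longer than 255 elements. The actual encoding lives in ServiceBufferFileStreamWriter and is not shown in this diff, so the sketch below (a marker byte followed by a four-byte count for the long form) is an assumption inferred from those expected sizes; LengthPrefixSketch and WriteLengthPrefix are hypothetical names:

    using System;
    using System.IO;

    internal static class LengthPrefixSketch
    {
        // Illustrative only: a length prefix whose size matches the lengthLength expression
        // in the tests above (1 byte for 1-255 elements, otherwise 5 bytes).
        public static int WriteLengthPrefix(Stream output, int elementCount)
        {
            if (elementCount > 0 && elementCount <= 255)
            {
                output.WriteByte((byte)elementCount);           // short form: single length byte
                return 1;
            }

            output.WriteByte(0xFF);                             // long-form marker (assumed value)
            byte[] countBytes = BitConverter.GetBytes(elementCount);
            output.Write(countBytes, 0, countBytes.Length);     // four-byte element count
            return 1 + countBytes.Length;                       // 5 bytes, matching the tests
        }
    }

Whatever the real layout is, the total size VerifyReadWrite expects is then the prefix size plus sizeof(char) or sizeof(byte) times the element count, which is exactly what StringTest and Bytes assert.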
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
index 49b6c76c..bda6ca0d 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
@@ -29,7 +29,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
         public void BatchCreationTest()
         {
             // If I create a new batch...
-            Batch batch = new Batch(Common.StandardQuery, 1);
+            Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
 
             // Then:
             // ... The text of the batch should be stored
@@ -52,7 +52,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
         public void BatchExecuteNoResultSets()
         {
             // If I execute a query that should get no result sets
-            Batch batch = new Batch(Common.StandardQuery, 1);
+            Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
             batch.Execute(GetConnection(Common.CreateTestConnectionInfo(null, false)), CancellationToken.None).Wait();
 
             // Then:
@@ -79,7 +79,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false);
 
             // If I execute a query that should get one result set
-            Batch batch = new Batch(Common.StandardQuery, 1);
+            Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
             batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
 
             // Then:
@@ -92,11 +92,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             Assert.Equal(resultSets, batch.ResultSummaries.Length);
 
             // ... Inside the result set should be with 5 rows
-            Assert.Equal(Common.StandardRows, batch.ResultSets.First().Rows.Count);
+            Assert.Equal(Common.StandardRows, batch.ResultSets.First().RowCount);
             Assert.Equal(Common.StandardRows, batch.ResultSummaries[0].RowCount);
 
-            // ... Inside the result set should have 5 columns and 5 column definitions
-            Assert.Equal(Common.StandardColumns, batch.ResultSets.First().Rows[0].Length);
+            // ... Inside the result set should have 5 columns
             Assert.Equal(Common.StandardColumns, batch.ResultSets.First().Columns.Length);
             Assert.Equal(Common.StandardColumns, batch.ResultSummaries[0].ColumnInfo.Length);
 
@@ -112,7 +111,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             ConnectionInfo ci = Common.CreateTestConnectionInfo(dataset, false);
 
             // If I execute a query that should get two result sets
-            Batch batch = new Batch(Common.StandardQuery, 1);
+            Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
             batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
 
             // Then:
@@ -126,10 +125,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             foreach (ResultSet rs in batch.ResultSets)
             {
                 // ... Each result set should have 5 rows
-                Assert.Equal(Common.StandardRows, rs.Rows.Count);
+                Assert.Equal(Common.StandardRows, rs.RowCount);
 
-                // ... Inside each result set should be 5 columns and 5 column definitions
-                Assert.Equal(Common.StandardColumns, rs.Rows[0].Length);
+                // ... Inside each result set should be 5 columns
                 Assert.Equal(Common.StandardColumns, rs.Columns.Length);
             }
 
@@ -155,7 +153,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true);
 
             // If I execute a batch that is invalid
-            Batch batch = new Batch(Common.StandardQuery, 1);
+            Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
             batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
 
             // Then:
@@ -177,7 +175,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false);
 
             // If I execute a batch
-            Batch batch = new Batch(Common.StandardQuery, 1);
+            Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
             batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
 
             // Then:
@@ -207,7 +205,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // ... I create a batch that has an empty query
             // Then:
             // ... It should throw an exception
-            Assert.Throws<ArgumentNullException>(() => new Batch(query, 1));
+            Assert.Throws<ArgumentNullException>(() => new Batch(query, 1, Common.GetFileStreamFactory()));
+        }
+
+        [Fact]
+        public void BatchNoBufferFactory()
+        {
+            // If:
+            // ... I create a batch that has no file stream factory
+            // Then:
+            // ... It should throw an exception
+            Assert.Throws<ArgumentNullException>(() => new Batch("stuff", 1, null));
         }
 
         #endregion
@@ -222,7 +230,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // Then:
             // ... It should throw an exception
             Assert.Throws<ArgumentNullException>(() =>
-                new Query(null, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings()));
+                new Query(null, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory()));
         }
 
         [Fact]
@@ -232,7 +240,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // ... I create a query that has a null connection info
             // Then:
             // ... It should throw an exception
-            Assert.Throws<ArgumentNullException>(() => new Query("Some Query", null, new QueryExecutionSettings()));
+            Assert.Throws<ArgumentNullException>(() => new Query("Some Query", null, new QueryExecutionSettings(), Common.GetFileStreamFactory()));
         }
 
         [Fact]
@@ -243,7 +251,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // Then:
             // ... It should throw an exception
             Assert.Throws<ArgumentNullException>(() =>
-                new Query("Some query", Common.CreateTestConnectionInfo(null, false), null));
+                new Query("Some query", Common.CreateTestConnectionInfo(null, false), null, Common.GetFileStreamFactory()));
+        }
+
+        [Fact]
+        public void QueryExecuteNoBufferFactory()
+        {
+            // If:
+            // ... I create a query that has a null file stream factory
+            // Then:
+            // ... It should throw an exception
+            Assert.Throws<ArgumentNullException>(() =>
+                new Query("Some query", Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), null));
         }
 
         [Fact]
@@ -252,7 +271,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // If:
             // ... I create a query from a single batch (without separator)
             ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
-            Query query = new Query(Common.StandardQuery, ci, new QueryExecutionSettings());
+            Query query = new Query(Common.StandardQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
 
             // Then:
             // ... I should get a single batch to execute that hasn't been executed
@@ -279,7 +298,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // If:
             // ... I create a query from a single batch that does nothing
             ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
-            Query query = new Query(Common.NoOpQuery, ci, new QueryExecutionSettings());
+            Query query = new Query(Common.NoOpQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
 
             // Then:
             // ... I should get no batches back
@@ -305,7 +324,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // ... I create a query from two batches (with separator)
             ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
             string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery);
-            Query query = new Query(queryText, ci, new QueryExecutionSettings());
+            Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
 
             // Then:
             // ... I should get back two batches to execute that haven't been executed
@@ -333,7 +352,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // ... I create a query from a two batches (with separator)
             ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
             string queryText = string.Format("{0}\r\nGO\r\n{1}", Common.StandardQuery, Common.NoOpQuery);
-            Query query = new Query(queryText, ci, new QueryExecutionSettings());
+            Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
 
             // Then:
             // ... I should get back one batch to execute that hasn't been executed
@@ -359,7 +378,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // If:
             // ... I create a query from an invalid batch
             ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true);
-            Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings());
+            Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
 
             // Then:
             // ... I should get back a query with one batch not executed
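Every Batch and Query constructed in these hunks now takes a trailing IFileStreamFactory, supplied in the tests by the Common.GetFileStreamFactory() helper. A minimal sketch of the updated calling pattern outside of the tests, assuming the caller already has a ConnectionInfo and a factory in hand (using directives are omitted because the namespaces of ConnectionInfo and QueryExecutionSettings are not shown in this diff):

    internal static class QueryConstructionSketch
    {
        // Illustrative fragment only: the four-argument constructor shape the tests exercise.
        // Query text, connection info, settings, and factory are all rejected when null,
        // as the tests above assert.
        public static Query CreateQuery(string sql, ConnectionInfo connectionInfo, IFileStreamFactory fileStreamFactory)
        {
            return new Query(sql, connectionInfo, new QueryExecutionSettings(), fileStreamFactory);
        }
    }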
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs
index a6f5e9fe..2968e709 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs
@@ -28,7 +28,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             Batch b = Common.GetBasicExecutedBatch();
 
             // ... And I ask for a subset with valid arguments
-            ResultSetSubset subset = b.GetSubset(0, 0, rowCount);
+            ResultSetSubset subset = b.GetSubset(0, 0, rowCount).Result;
 
             // Then:
             // I should get the requested number of rows
@@ -51,7 +51,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // ... And I ask for a subset with an invalid result set index
             // Then:
             // ... It should throw an exception
-            Assert.Throws(() => b.GetSubset(resultSetIndex, rowStartInex, rowCount));
+            Assert.ThrowsAsync(() => b.GetSubset(resultSetIndex, rowStartInex, rowCount)).Wait();
         }
 
         #endregion
@@ -62,12 +62,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
         public void SubsetUnexecutedQueryTest()
        {
             // If I have a query that has *not* been executed
-            Query q = new Query(Common.StandardQuery, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings());
+            Query q = new Query(Common.StandardQuery, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory());
 
             // ... And I ask for a subset with valid arguments
             // Then:
             // ... It should throw an exception
-            Assert.Throws(() => q.GetSubset(0, 0, 0, 2));
+            Assert.ThrowsAsync(() => q.GetSubset(0, 0, 0, 2)).Wait();
         }
 
         [Theory]
@@ -81,7 +81,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
             // ... And I ask for a subset with an invalid result set index
             // Then:
             // ... It should throw an exception
-            Assert.Throws(() => q.GetSubset(batchIndex, 0, 0, 1));
+            Assert.ThrowsAsync(() => q.GetSubset(batchIndex, 0, 0, 1)).Wait();
         }
 
         #endregion
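Because GetSubset on both Batch and Query now returns Task<ResultSetSubset>, these tests either block on .Result or wrap the call in Assert.ThrowsAsync(...).Wait(). Non-test callers would more naturally await the call; a minimal sketch, assuming a Batch that has already been executed:

    using System.Threading.Tasks;
    using Microsoft.SqlTools.ServiceLayer.QueryExecution;
    using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;

    internal static class SubsetUsageSketch
    {
        // Illustrative only: awaiting the now-asynchronous GetSubset instead of blocking on .Result.
        public static async Task<ResultSetSubset> ReadFirstRowsAsync(Batch executedBatch, int rowCount)
        {
            // Result set 0, starting at row 0, reading rowCount rows
            return await executedBatch.GetSubset(0, 0, rowCount);
        }
    }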
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs
new file mode 100644
index 00000000..c2765783
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs
@@ -0,0 +1,21 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Data.Common;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
+{
+    public class TestDbColumn : DbColumn
+    {
+        public TestDbColumn()
+        {
+            base.IsLong = false;
+            base.ColumnName = "Test Column";
+            base.ColumnSize = 128;
+            base.AllowDBNull = true;
+            base.DataType = typeof(string);
+            base.DataTypeName = "nvarchar";
+        }
+    }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs
index e2003789..0330cda0 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs
@@ -64,6 +64,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
             return this[ordinal];
         }
 
+        public override int GetValues(object[] values)
+        {
+            for (int i = 0; i < Rows.Current.Count; i++)
+            {
+                values[i] = this[i];
+            }
+            return Rows.Current.Count;
+        }
+
         public override object this[string name]
         {
             get { return Rows.Current[name]; }
@@ -84,11 +93,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
             List<DbColumn> columns = new List<DbColumn>();
             for (int i = 0; i < ResultSet.Current[0].Count; i++)
             {
-                columns.Add(new Mock<DbColumn>().Object);
+                columns.Add(new TestDbColumn());
             }
             return new ReadOnlyCollection<DbColumn>(columns);
         }
 
+        public override bool IsDBNull(int ordinal)
+        {
+            return this[ordinal] == null;
+        }
+
         public override int FieldCount
         {
             get { return Rows?.Current.Count ?? 0; }
         }
 
         public override int RecordsAffected
@@ -189,16 +203,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
             throw new NotImplementedException();
         }
 
-        public override int GetValues(object[] values)
-        {
-            throw new NotImplementedException();
-        }
-
-        public override bool IsDBNull(int ordinal)
-        {
-            throw new NotImplementedException();
-        }
-
         public override IEnumerator GetEnumerator()
         {
             throw new NotImplementedException();