diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs
index 38528b60..69250afe 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Batch.cs
@@ -1,7 +1,6 @@
-//
+//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
-//
using System;
using System.Collections.Generic;
@@ -11,18 +10,60 @@ using System.Data.SqlClient;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.SqlTools.EditorServices.Utility;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
///
/// This class represents a batch within a query
///
- public class Batch
+ public class Batch : IDisposable
{
private const string RowsAffectedFormat = "({0} row(s) affected)";
+ #region Member Variables
+
+ ///
+ /// For IDisposable implementation, whether or not this has been disposed
+ ///
+ private bool disposed;
+
+ ///
+ /// Factory for creating readers/writrs for the output of the batch
+ ///
+ private readonly IFileStreamFactory outputFileFactory;
+
+ ///
+ /// Internal representation of the messages so we can modify internally
+ ///
+ private readonly List resultMessages;
+
+ ///
+ /// Internal representation of the result sets so we can modify internally
+ ///
+ private readonly List resultSets;
+
+ #endregion
+
+ internal Batch(string batchText, int startLine, IFileStreamFactory outputFileFactory)
+ {
+ // Sanity check for input
+ Validate.IsNotNullOrEmptyString(nameof(batchText), batchText);
+ Validate.IsNotNull(nameof(outputFileFactory), outputFileFactory);
+
+ // Initialize the internal state
+ BatchText = batchText;
+ StartLine = startLine - 1; // -1 to make sure that the line number of the batch is 0-indexed, since SqlParser gives 1-indexed line numbers
+ HasExecuted = false;
+ resultSets = new List();
+ resultMessages = new List();
+ this.outputFileFactory = outputFileFactory;
+ }
+
#region Properties
+
///
/// The text of batch that will be executed
///
@@ -38,11 +79,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
public bool HasExecuted { get; set; }
- ///
- /// Internal representation of the messages so we can modify internally
- ///
- private List resultMessages;
-
///
/// Messages that have come back from the server
///
@@ -51,11 +87,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
get { return resultMessages; }
}
- ///
- /// Internal representation of the result sets so we can modify internally
- ///
- private List resultSets;
-
///
/// The result sets of the batch execution
///
@@ -75,7 +106,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
ColumnInfo = set.Columns,
Id = index,
- RowCount = set.Rows.Count
+ RowCount = set.RowCount
}).ToArray();
}
}
@@ -87,21 +118,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
#endregion
- public Batch(string batchText, int startLine)
- {
- // Sanity check for input
- if (string.IsNullOrEmpty(batchText))
- {
- throw new ArgumentNullException(nameof(batchText), "Query text cannot be null");
- }
-
- // Initialize the internal state
- BatchText = batchText;
- StartLine = startLine - 1; // -1 to make sure that the line number of the batch is 0-indexed, since SqlParser gives 1-indexed line numbers
- HasExecuted = false;
- resultSets = new List();
- resultMessages = new List();
- }
+ #region Public Methods
///
/// Executes this batch and captures any server messages that are returned.
@@ -148,23 +165,14 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
// Read until we hit the end of the result set
- ResultSet resultSet = new ResultSet();
- while (await reader.ReadAsync(cancellationToken))
- {
- resultSet.AddRow(reader);
- }
-
- // Read off the column schema information
- if (reader.CanGetColumnSchema())
- {
- resultSet.Columns = reader.GetColumnSchema().ToArray();
- }
+ ResultSet resultSet = new ResultSet(reader, outputFileFactory);
+ await resultSet.ReadResultToEnd(cancellationToken);
// Add the result set to the results of the query
resultSets.Add(resultSet);
// Add a message for the number of rows the query returned
- resultMessages.Add(string.Format(RowsAffectedFormat, resultSet.Rows.Count));
+ resultMessages.Add(string.Format(RowsAffectedFormat, resultSet.RowCount));
} while (await reader.NextResultAsync(cancellationToken));
}
}
@@ -200,7 +208,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// The starting row of the results
/// How many rows to retrieve
/// A subset of results
- public ResultSetSubset GetSubset(int resultSetIndex, int startRow, int rowCount)
+ public Task GetSubset(int resultSetIndex, int startRow, int rowCount)
{
// Sanity check to make sure we have valid numbers
if (resultSetIndex < 0 || resultSetIndex >= resultSets.Count)
@@ -213,6 +221,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
return resultSets[resultSetIndex].GetSubset(startRow, rowCount);
}
+ #endregion
+
#region Private Helpers
///
@@ -259,5 +269,33 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
#endregion
+
+ #region IDisposable Implementation
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposed)
+ {
+ return;
+ }
+
+ if (disposing)
+ {
+ foreach (ResultSet r in ResultSets)
+ {
+ r.Dispose();
+ }
+ }
+
+ disposed = true;
+ }
+
+ #endregion
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs
new file mode 100644
index 00000000..e80eada5
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/DbColumnWrapper.cs
@@ -0,0 +1,226 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Data.Common;
+using System.Data.SqlTypes;
+using System.Diagnostics;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts
+{
+ ///
+ /// Wrapper around a DbColumn, which provides extra functionality, but can be used as a
+ /// regular DbColumn
+ ///
+ public class DbColumnWrapper : DbColumn
+ {
+ ///
+ /// All types supported by the server, stored as a hash set to provide O(1) lookup
+ ///
+ private static readonly HashSet AllServerDataTypes = new HashSet
+ {
+ "bigint",
+ "binary",
+ "bit",
+ "char",
+ "datetime",
+ "decimal",
+ "float",
+ "image",
+ "int",
+ "money",
+ "nchar",
+ "ntext",
+ "nvarchar",
+ "real",
+ "uniqueidentifier",
+ "smalldatetime",
+ "smallint",
+ "smallmoney",
+ "text",
+ "timestamp",
+ "tinyint",
+ "varbinary",
+ "varchar",
+ "sql_variant",
+ "xml",
+ "date",
+ "time",
+ "datetimeoffset",
+ "datetime2"
+ };
+
+ private readonly DbColumn internalColumn;
+
+ ///
+ /// Constructor for a DbColumnWrapper
+ ///
+ /// Most of this logic is taken from SSMS ColumnInfo class
+ /// The column we're wrapping around
+ public DbColumnWrapper(DbColumn column)
+ {
+ internalColumn = column;
+
+ switch (column.DataTypeName)
+ {
+ case "varchar":
+ case "nvarchar":
+ IsChars = true;
+
+ Debug.Assert(column.ColumnSize.HasValue);
+ if (column.ColumnSize.Value == int.MaxValue)
+ {
+ //For Yukon, special case nvarchar(max) with column name == "Microsoft SQL Server 2005 XML Showplan" -
+ //assume it is an XML showplan.
+ //Please note this field must be in sync with a similar field defined in QESQLBatch.cs.
+ //This is not the best fix that we could do but we are trying to minimize code impact
+ //at this point. Post Yukon we should review this code again and avoid
+ //hard-coding special column name in multiple places.
+ const string YukonXmlShowPlanColumn = "Microsoft SQL Server 2005 XML Showplan";
+ if (column.ColumnName == YukonXmlShowPlanColumn)
+ {
+ // Indicate that this is xml to apply the right size limit
+ // Note we leave chars type as well to use the right retrieval mechanism.
+ IsXml = true;
+ }
+ IsLong = true;
+ }
+ break;
+ case "text":
+ case "ntext":
+ IsChars = true;
+ IsLong = true;
+ break;
+ case "xml":
+ IsXml = true;
+ IsLong = true;
+ break;
+ case "binary":
+ case "image":
+ IsBytes = true;
+ IsLong = true;
+ break;
+ case "varbinary":
+ case "rowversion":
+ IsBytes = true;
+
+ Debug.Assert(column.ColumnSize.HasValue);
+ if (column.ColumnSize.Value == int.MaxValue)
+ {
+ IsLong = true;
+ }
+ break;
+ case "sql_variant":
+ IsSqlVariant = true;
+ break;
+ default:
+ if (!AllServerDataTypes.Contains(column.DataTypeName))
+ {
+ // treat all UDT's as long/bytes data types to prevent the CLR from attempting
+ // to load the UDT assembly into our process to call ToString() on the object.
+
+ IsUdt = true;
+ IsBytes = true;
+ IsLong = true;
+ }
+ break;
+ }
+
+
+ if (IsUdt)
+ {
+ // udtassemblyqualifiedname property is used to find if the datatype is of hierarchyid assembly type
+ // Internally hiearchyid is sqlbinary so providerspecific type and type is changed to sqlbinarytype
+ object assemblyQualifiedName = internalColumn.UdtAssemblyQualifiedName;
+ const string hierarchyId = "MICROSOFT.SQLSERVER.TYPES.SQLHIERARCHYID";
+
+ if (assemblyQualifiedName != null &&
+ string.Equals(assemblyQualifiedName.ToString(), hierarchyId, StringComparison.OrdinalIgnoreCase))
+ {
+ DataType = typeof(SqlBinary);
+ }
+ else
+ {
+ DataType = typeof(byte[]);
+ }
+ }
+ else
+ {
+ DataType = DataType;
+ }
+ }
+
+ #region Properties
+
+ ///
+ /// Whether or not the column is bytes
+ ///
+ public bool IsBytes { get; private set; }
+
+ ///
+ /// Whether or not the column is a character type
+ ///
+ public bool IsChars { get; private set; }
+
+ ///
+ /// Whether or not the column is a long type (eg, varchar(MAX))
+ ///
+ public new bool IsLong { get; private set; }
+
+ ///
+ /// Whether or not the column is a SqlVariant type
+ ///
+ public bool IsSqlVariant { get; private set; }
+
+ ///
+ /// Whether or not the column is a user-defined type
+ ///
+ public bool IsUdt { get; private set; }
+
+ ///
+ /// Whether or not the column is XML
+ ///
+ public bool IsXml { get; private set; }
+
+ #endregion
+
+ #region DbColumn Fields
+
+ ///
+ /// Override for column name, if null or empty, we default to a "no column name" value
+ ///
+ public new string ColumnName
+ {
+ get
+ {
+ // TODO: Localize
+ return string.IsNullOrEmpty(internalColumn.ColumnName) ? "(No column name)" : internalColumn.ColumnName;
+ }
+ }
+
+ public new bool? AllowDBNull { get { return internalColumn.AllowDBNull; } }
+ public new string BaseCatalogName { get { return internalColumn.BaseCatalogName; } }
+ public new string BaseColumnName { get { return internalColumn.BaseColumnName; } }
+ public new string BaseServerName { get { return internalColumn.BaseServerName; } }
+ public new string BaseTableName { get { return internalColumn.BaseTableName; } }
+ public new int? ColumnOrdinal { get { return internalColumn.ColumnOrdinal; } }
+ public new int? ColumnSize { get { return internalColumn.ColumnSize; } }
+ public new bool? IsAliased { get { return internalColumn.IsAliased; } }
+ public new bool? IsAutoIncrement { get { return internalColumn.IsAutoIncrement; } }
+ public new bool? IsExpression { get { return internalColumn.IsExpression; } }
+ public new bool? IsHidden { get { return internalColumn.IsHidden; } }
+ public new bool? IsIdentity { get { return internalColumn.IsIdentity; } }
+ public new bool? IsKey { get { return internalColumn.IsKey; } }
+ public new bool? IsReadOnly { get { return internalColumn.IsReadOnly; } }
+ public new bool? IsUnique { get { return internalColumn.IsUnique; } }
+ public new int? NumericPrecision { get { return internalColumn.NumericPrecision; } }
+ public new int? NumericScale { get { return internalColumn.NumericScale; } }
+ public new string UdtAssemblyQualifiedName { get { return internalColumn.UdtAssemblyQualifiedName; } }
+ public new Type DataType { get; private set; }
+ public new string DataTypeName { get { return internalColumn.DataTypeName; } }
+
+ #endregion
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs
index b0a6d75c..c8705d8b 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs
@@ -3,8 +3,6 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
-using System.Data.Common;
-
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts
{
///
@@ -20,11 +18,11 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts
///
/// The number of rows that was returned with the resultset
///
- public int RowCount { get; set; }
+ public long RowCount { get; set; }
///
/// Details about the columns that are provided as solutions
///
- public DbColumn[] ColumnInfo { get; set; }
+ public DbColumnWrapper[] ColumnInfo { get; set; }
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs
new file mode 100644
index 00000000..61ee62e0
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamReadResult.cs
@@ -0,0 +1,50 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Represents a value returned from a read from a file stream. This is used to eliminate ref
+ /// parameters used in the read methods.
+ ///
+ /// The type of the value that was read
+ public struct FileStreamReadResult
+ {
+ ///
+ /// Whether or not the value of the field is null
+ ///
+ public bool IsNull { get; set; }
+
+ ///
+ /// The value of the field. If is true, this will be set to default(T)
+ ///
+ public T Value { get; set; }
+
+ ///
+ /// The total length in bytes of the value, (including the bytes used to store the length
+ /// of the value)
+ ///
+ ///
+ /// Cell values are stored such that the length of the value is stored first, then the
+ /// value itself is stored. Eg, a string may be stored as 0x03 0x6C 0x6F 0x6C. Under this
+ /// system, the value would be "lol", the length would be 3, and the total length would be
+ /// 4 bytes.
+ ///
+ public int TotalLength { get; set; }
+
+ ///
+ /// Constructs a new FileStreamReadResult
+ ///
+ /// The value of the result
+ /// The number of bytes for the used to store the value's length and value
+ /// Whether or not the value is null
+ public FileStreamReadResult(T value, int totalLength, bool isNull)
+ {
+ Value = value;
+ TotalLength = totalLength;
+ IsNull = isNull;
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs
new file mode 100644
index 00000000..afe616f3
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/FileStreamWrapper.cs
@@ -0,0 +1,282 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Diagnostics;
+using System.IO;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Wrapper for a file stream, providing simplified creation, deletion, read, and write
+ /// functionality.
+ ///
+ public class FileStreamWrapper : IFileStreamWrapper
+ {
+ #region Member Variables
+
+ private byte[] buffer;
+ private int bufferDataSize;
+ private FileStream fileStream;
+ private long startOffset;
+ private long currentOffset;
+
+ #endregion
+
+ ///
+ /// Constructs a new FileStreamWrapper and initializes its state.
+ ///
+ public FileStreamWrapper()
+ {
+ // Initialize the internal state
+ bufferDataSize = 0;
+ startOffset = 0;
+ currentOffset = 0;
+ }
+
+ #region IFileStreamWrapper Implementation
+
+ ///
+ /// Initializes the wrapper by creating the internal buffer and opening the requested file.
+ /// If the file does not already exist, it will be created.
+ ///
+ /// Name of the file to open/create
+ /// The length of the internal buffer
+ ///
+ /// Whether or not the wrapper will be used for reading. If true, any calls to a
+ /// method that writes will cause an InvalidOperationException
+ ///
+ public void Init(string fileName, int bufferLength, FileAccess accessMethod)
+ {
+ // Sanity check for valid buffer length, fileName, and accessMethod
+ if (bufferLength <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(bufferLength), "Buffer length must be a positive value");
+ }
+ if (string.IsNullOrWhiteSpace(fileName))
+ {
+ throw new ArgumentNullException(nameof(fileName), "File name cannot be null or whitespace");
+ }
+ if (accessMethod == FileAccess.Write)
+ {
+ throw new ArgumentException("Access method cannot be write-only", nameof(fileName));
+ }
+
+ // Setup the buffer
+ buffer = new byte[bufferLength];
+
+ // Open the requested file for reading/writing, creating one if it doesn't exist
+ fileStream = new FileStream(fileName, FileMode.OpenOrCreate, accessMethod, FileShare.ReadWrite,
+ bufferLength, false /*don't use asyncio*/);
+
+ // make file hidden
+ FileInfo fileInfo = new FileInfo(fileName);
+ fileInfo.Attributes |= FileAttributes.Hidden;
+ }
+
+ ///
+ /// Reads data into a buffer from the current offset into the file
+ ///
+ /// The buffer to output the read data to
+ /// The number of bytes to read into the buffer
+ /// The number of bytes read
+ public int ReadData(byte[] buf, int bytes)
+ {
+ return ReadData(buf, bytes, currentOffset);
+ }
+
+ ///
+ /// Reads data into a buffer from the specified offset into the file
+ ///
+ /// The buffer to output the read data to
+ /// The number of bytes to read into the buffer
+ /// The offset into the file to start reading bytes from
+ /// The number of bytes read
+ public int ReadData(byte[] buf, int bytes, long offset)
+ {
+ // Make sure that we're initialized before performing operations
+ if (buffer == null)
+ {
+ throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations");
+ }
+
+ MoveTo(offset);
+
+ int bytesCopied = 0;
+ while (bytesCopied < bytes)
+ {
+ int bufferOffset, bytesToCopy;
+ GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy);
+ Buffer.BlockCopy(buffer, bufferOffset, buf, bytesCopied, bytesToCopy);
+ bytesCopied += bytesToCopy;
+
+ if (bytesCopied < bytes && // did not get all the bytes yet
+ bufferDataSize == buffer.Length) // since current data buffer is full we should continue reading the file
+ {
+ // move forward one full length of the buffer
+ MoveTo(startOffset + buffer.Length);
+ }
+ else
+ {
+ // copied all the bytes requested or possible, adjust the current buffer pointer
+ currentOffset += bytesToCopy;
+ break;
+ }
+ }
+ return bytesCopied;
+ }
+
+ ///
+ /// Writes data to the underlying filestream, with buffering.
+ ///
+ /// The buffer of bytes to write to the filestream
+ /// The number of bytes to write
+ /// The number of bytes written
+ public int WriteData(byte[] buf, int bytes)
+ {
+ // Make sure that we're initialized before performing operations
+ if (buffer == null)
+ {
+ throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations");
+ }
+ if (!fileStream.CanWrite)
+ {
+ throw new InvalidOperationException("This FileStreamWrapper canot be used for writing");
+ }
+
+ int bytesCopied = 0;
+ while (bytesCopied < bytes)
+ {
+ int bufferOffset, bytesToCopy;
+ GetByteCounts(bytes, bytesCopied, out bufferOffset, out bytesToCopy);
+ Buffer.BlockCopy(buf, bytesCopied, buffer, bufferOffset, bytesToCopy);
+ bytesCopied += bytesToCopy;
+
+ // adjust the current buffer pointer
+ currentOffset += bytesToCopy;
+
+ if (bytesCopied < bytes) // did not get all the bytes yet
+ {
+ Debug.Assert((int)(currentOffset - startOffset) == buffer.Length);
+ // flush buffer
+ Flush();
+ }
+ }
+ Debug.Assert(bytesCopied == bytes);
+ return bytesCopied;
+ }
+
+ ///
+ /// Flushes the internal buffer to the filestream
+ ///
+ public void Flush()
+ {
+ // Make sure that we're initialized before performing operations
+ if (buffer == null)
+ {
+ throw new InvalidOperationException("FileStreamWrapper must be initialized before performing operations");
+ }
+ if (!fileStream.CanWrite)
+ {
+ throw new InvalidOperationException("This FileStreamWrapper cannot be used for writing");
+ }
+
+ // Make sure we are at the right place in the file
+ Debug.Assert(fileStream.Position == startOffset);
+
+ int bytesToWrite = (int)(currentOffset - startOffset);
+ fileStream.Write(buffer, 0, bytesToWrite);
+ startOffset += bytesToWrite;
+ fileStream.Flush();
+
+ Debug.Assert(startOffset == currentOffset);
+ }
+
+ ///
+ /// Deletes the given file (ideally, created with this wrapper) from the filesystem
+ ///
+ /// The path to the file to delete
+ public static void DeleteFile(string fileName)
+ {
+ File.Delete(fileName);
+ }
+
+ #endregion
+
+ ///
+ /// Perform calculations to determine how many bytes to copy and what the new buffer offset
+ /// will be for copying.
+ ///
+ /// Number of bytes requested to copy
+ /// Number of bytes copied so far
+ /// New offset to start copying from/to
+ /// Number of bytes to copy in this iteration
+ private void GetByteCounts(int bytes, int bytesCopied, out int bufferOffset, out int bytesToCopy)
+ {
+ bufferOffset = (int) (currentOffset - startOffset);
+ bytesToCopy = bytes - bytesCopied;
+ if (bytesToCopy > buffer.Length - bufferOffset)
+ {
+ bytesToCopy = buffer.Length - bufferOffset;
+ }
+ }
+
+ ///
+ /// Moves the internal buffer to the specified offset into the file
+ ///
+ /// Offset into the file to move to
+ private void MoveTo(long offset)
+ {
+ if (buffer.Length > bufferDataSize || // buffer is not completely filled
+ offset < startOffset || // before current buffer start
+ offset >= (startOffset + buffer.Length)) // beyond current buffer end
+ {
+ // init the offset
+ startOffset = offset;
+
+ // position file pointer
+ fileStream.Seek(startOffset, SeekOrigin.Begin);
+
+ // fill in the buffer
+ bufferDataSize = fileStream.Read(buffer, 0, buffer.Length);
+ }
+ // make sure to record where we are
+ currentOffset = offset;
+ }
+
+ #region IDisposable Implementation
+
+ private bool disposed;
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposed)
+ {
+ return;
+ }
+
+ if (disposing && fileStream != null)
+ {
+ if(fileStream.CanWrite) { Flush(); }
+ fileStream.Dispose();
+ }
+
+ disposed = true;
+ }
+
+ ~FileStreamWrapper()
+ {
+ Dispose(false);
+ }
+
+ #endregion
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs
new file mode 100644
index 00000000..6cb50095
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamFactory.cs
@@ -0,0 +1,22 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Interface for a factory that creates filesystem readers/writers
+ ///
+ public interface IFileStreamFactory
+ {
+ string CreateFile();
+
+ IFileStreamReader GetReader(string fileName);
+
+ IFileStreamWriter GetWriter(string fileName, int maxCharsToStore, int maxXmlCharsToStore);
+
+ void DisposeFile(string fileName);
+
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs
new file mode 100644
index 00000000..ea5584f1
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamReader.cs
@@ -0,0 +1,35 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Interface for a object that reads from the filesystem
+ ///
+ public interface IFileStreamReader : IDisposable
+ {
+ object[] ReadRow(long offset, IEnumerable columns);
+ FileStreamReadResult ReadInt16(long i64Offset);
+ FileStreamReadResult ReadInt32(long i64Offset);
+ FileStreamReadResult ReadInt64(long i64Offset);
+ FileStreamReadResult ReadByte(long i64Offset);
+ FileStreamReadResult ReadChar(long i64Offset);
+ FileStreamReadResult ReadBoolean(long i64Offset);
+ FileStreamReadResult ReadSingle(long i64Offset);
+ FileStreamReadResult ReadDouble(long i64Offset);
+ FileStreamReadResult ReadSqlDecimal(long i64Offset);
+ FileStreamReadResult ReadDecimal(long i64Offset);
+ FileStreamReadResult ReadDateTime(long i64Offset);
+ FileStreamReadResult ReadTimeSpan(long i64Offset);
+ FileStreamReadResult ReadString(long i64Offset);
+ FileStreamReadResult ReadBytes(long i64Offset);
+ FileStreamReadResult ReadDateTimeOffset(long i64Offset);
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs
new file mode 100644
index 00000000..38c283c5
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWrapper.cs
@@ -0,0 +1,22 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.IO;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Interface for a wrapper around a filesystem reader/writer, mainly for unit testing purposes
+ ///
+ public interface IFileStreamWrapper : IDisposable
+ {
+ void Init(string fileName, int bufferSize, FileAccess fileAccessMode);
+ int ReadData(byte[] buffer, int bytes);
+ int ReadData(byte[] buffer, int bytes, long fileOffset);
+ int WriteData(byte[] buffer, int bytes);
+ void Flush();
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs
new file mode 100644
index 00000000..968701ed
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/IFileStreamWriter.cs
@@ -0,0 +1,35 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Data.SqlTypes;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Interface for a object that writes to a filesystem wrapper
+ ///
+ public interface IFileStreamWriter : IDisposable
+ {
+ int WriteRow(StorageDataReader dataReader);
+ int WriteNull();
+ int WriteInt16(short val);
+ int WriteInt32(int val);
+ int WriteInt64(long val);
+ int WriteByte(byte val);
+ int WriteChar(char val);
+ int WriteBoolean(bool val);
+ int WriteSingle(float val);
+ int WriteDouble(double val);
+ int WriteDecimal(decimal val);
+ int WriteSqlDecimal(SqlDecimal val);
+ int WriteDateTime(DateTime val);
+ int WriteDateTimeOffset(DateTimeOffset dtoVal);
+ int WriteTimeSpan(TimeSpan val);
+ int WriteString(string val);
+ int WriteBytes(byte[] bytes, int length);
+ void FlushBuffer();
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs
new file mode 100644
index 00000000..c06a13ac
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamFactory.cs
@@ -0,0 +1,64 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System.IO;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Factory that creates file reader/writers that process rows in an internal, non-human readable file format
+ ///
+ public class ServiceBufferFileStreamFactory : IFileStreamFactory
+ {
+ ///
+ /// Creates a new temporary file
+ ///
+ /// The name of the temporary file
+ public string CreateFile()
+ {
+ return Path.GetTempFileName();
+ }
+
+ ///
+ /// Creates a new for reading values back from
+ /// an SSMS formatted buffer file
+ ///
+ /// The file to read values from
+ /// A
+ public IFileStreamReader GetReader(string fileName)
+ {
+ return new ServiceBufferFileStreamReader(new FileStreamWrapper(), fileName);
+ }
+
+ ///
+ /// Creates a new for writing values out to an
+ /// SSMS formatted buffer file
+ ///
+ /// The file to write values to
+ /// The maximum number of characters to store from long text fields
+ /// The maximum number of characters to store from xml fields
+ /// A
+ public IFileStreamWriter GetWriter(string fileName, int maxCharsToStore, int maxXmlCharsToStore)
+ {
+ return new ServiceBufferFileStreamWriter(new FileStreamWrapper(), fileName, maxCharsToStore, maxXmlCharsToStore);
+ }
+
+ ///
+ /// Disposes of a file created via this factory
+ ///
+ /// The file to dispose of
+ public void DisposeFile(string fileName)
+ {
+ try
+ {
+ FileStreamWrapper.DeleteFile(fileName);
+ }
+ catch
+ {
+ // If we have problems deleting the file from a temp location, we don't really care
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs
new file mode 100644
index 00000000..0cfc2466
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/DataStorage/ServiceBufferFileStreamReader.cs
@@ -0,0 +1,889 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Diagnostics;
+using System.IO;
+using System.Text;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage
+{
+ ///
+ /// Reader for SSMS formatted file streams
+ ///
+ public class ServiceBufferFileStreamReader : IFileStreamReader
+ {
+ // Most of this code is based on code from the Microsoft.SqlServer.Management.UI.Grid, SSMS DataStorage
+ // $\Data Tools\SSMS_XPlat\sql\ssms\core\DataStorage\src\FileStreamReader.cs
+
+ private const int DefaultBufferSize = 8192;
+
+ #region Member Variables
+
+ private byte[] buffer;
+
+ private readonly IFileStreamWrapper fileStream;
+
+ #endregion
+
+ ///
+ /// Constructs a new ServiceBufferFileStreamReader and initializes its state
+ ///
+ /// The filestream wrapper to read from
+ /// The name of the file to read from
+ public ServiceBufferFileStreamReader(IFileStreamWrapper fileWrapper, string fileName)
+ {
+ // Open file for reading/writing
+ fileStream = fileWrapper;
+ fileStream.Init(fileName, DefaultBufferSize, FileAccess.Read);
+
+ // Create internal buffer
+ buffer = new byte[DefaultBufferSize];
+ }
+
+ #region IFileStreamStorage Implementation
+
+ ///
+ /// Reads a row from the file, based on the columns provided
+ ///
+ /// Offset into the file where the row starts
+ /// The columns that were encoded
+ /// The objects from the row
+ public object[] ReadRow(long fileOffset, IEnumerable columns)
+ {
+ // Initialize for the loop
+ long currentFileOffset = fileOffset;
+ List
public class Query : IDisposable
{
- #region Constants
-
///
/// "Error" code produced by SQL Server when the database context (name) for a connection changes.
///
private const int DatabaseContextChangeErrorNumber = 5701;
+ #region Member Variables
+
+ ///
+ /// Cancellation token source, used for cancelling async db actions
+ ///
+ private readonly CancellationTokenSource cancellationSource;
+
+ ///
+ /// For IDisposable implementation, whether or not this object has been disposed
+ ///
+ private bool disposed;
+
+ ///
+ /// The connection info associated with the file editor owner URI, used to create a new
+ /// connection upon execution of the query
+ ///
+ private readonly ConnectionInfo editorConnection;
+
+ ///
+ /// Whether or not the execute method has been called for this query
+ ///
+ private bool hasExecuteBeenCalled;
+
+ ///
+ /// The factory to use for outputting the results of this query
+ ///
+ private readonly IFileStreamFactory outputFileFactory;
+
#endregion
+ ///
+ /// Constructor for a query
+ ///
+ /// The text of the query to execute
+ /// The information of the connection to use to execute the query
+ /// Settings for how to execute the query, from the user
+ /// Factory for creating output files
+ public Query(string queryText, ConnectionInfo connection, QueryExecutionSettings settings, IFileStreamFactory outputFactory)
+ {
+ // Sanity check for input
+ if (string.IsNullOrEmpty(queryText))
+ {
+ throw new ArgumentNullException(nameof(queryText), "Query text cannot be null");
+ }
+ if (connection == null)
+ {
+ throw new ArgumentNullException(nameof(connection), "Connection cannot be null");
+ }
+ if (settings == null)
+ {
+ throw new ArgumentNullException(nameof(settings), "Settings cannot be null");
+ }
+ if (outputFactory == null)
+ {
+ throw new ArgumentNullException(nameof(outputFactory), "Output file factory cannot be null");
+ }
+
+ // Initialize the internal state
+ QueryText = queryText;
+ editorConnection = connection;
+ cancellationSource = new CancellationTokenSource();
+ outputFileFactory = outputFactory;
+
+ // Process the query into batches
+ ParseResult parseResult = Parser.Parse(queryText, new ParseOptions
+ {
+ BatchSeparator = settings.BatchSeparator
+ });
+ // NOTE: We only want to process batches that have statements (ie, ignore comments and empty lines)
+ Batches = parseResult.Script.Batches.Where(b => b.Statements.Count > 0)
+ .Select(b => new Batch(b.Sql, b.StartLocation.LineNumber, outputFileFactory)).ToArray();
+ }
+
#region Properties
///
@@ -59,19 +128,6 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
}
- ///
- /// Cancellation token source, used for cancelling async db actions
- ///
- private readonly CancellationTokenSource cancellationSource;
-
- ///
- /// The connection info associated with the file editor owner URI, used to create a new
- /// connection upon execution of the query
- ///
- private ConnectionInfo EditorConnection { get; set; }
-
- private bool HasExecuteBeenCalled { get; set; }
-
///
/// Whether or not the query has completed executed, regardless of success or failure
///
@@ -80,10 +136,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
public bool HasExecuted
{
- get { return Batches.Length == 0 ? HasExecuteBeenCalled : Batches.All(b => b.HasExecuted); }
+ get { return Batches.Length == 0 ? hasExecuteBeenCalled : Batches.All(b => b.HasExecuted); }
internal set
{
- HasExecuteBeenCalled = value;
+ hasExecuteBeenCalled = value;
foreach (var batch in Batches)
{
batch.HasExecuted = value;
@@ -98,41 +154,21 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
#endregion
+ #region Public Methods
+
///
- /// Constructor for a query
+ /// Cancels the query by issuing the cancellation token
///
- /// The text of the query to execute
- /// The information of the connection to use to execute the query
- /// Settings for how to execute the query, from the user
- public Query(string queryText, ConnectionInfo connection, QueryExecutionSettings settings)
+ public void Cancel()
{
- // Sanity check for input
- if (string.IsNullOrEmpty(queryText))
+ // Make sure that the query hasn't completed execution
+ if (HasExecuted)
{
- throw new ArgumentNullException(nameof(queryText), "Query text cannot be null");
- }
- if (connection == null)
- {
- throw new ArgumentNullException(nameof(connection), "Connection cannot be null");
- }
- if (settings == null)
- {
- throw new ArgumentNullException(nameof(settings), "Settings cannot be null");
+ throw new InvalidOperationException("The query has already completed, it cannot be cancelled.");
}
- // Initialize the internal state
- QueryText = queryText;
- EditorConnection = connection;
- cancellationSource = new CancellationTokenSource();
-
- // Process the query into batches
- ParseResult parseResult = Parser.Parse(queryText, new ParseOptions
- {
- BatchSeparator = settings.BatchSeparator
- });
- // NOTE: We only want to process batches that have statements (ie, ignore comments and empty lines)
- Batches = parseResult.Script.Batches.Where(b => b.Statements.Count > 0)
- .Select(b => new Batch(b.Sql, b.StartLocation.LineNumber)).ToArray();
+ // Issue the cancellation token for the query
+ cancellationSource.Cancel();
}
///
@@ -141,7 +177,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
public async Task Execute()
{
// Mark that we've internally executed
- HasExecuteBeenCalled = true;
+ hasExecuteBeenCalled = true;
// Don't actually execute if there aren't any batches to execute
if (Batches.Length == 0)
@@ -150,8 +186,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
}
// Open up a connection for querying the database
- string connectionString = ConnectionService.BuildConnectionString(EditorConnection.ConnectionDetails);
- using (DbConnection conn = EditorConnection.Factory.CreateSqlConnection(connectionString))
+ string connectionString = ConnectionService.BuildConnectionString(editorConnection.ConnectionDetails);
+ // TODO: Don't create a new connection every time, see TFS #834978
+ using (DbConnection conn = editorConnection.Factory.CreateSqlConnection(connectionString))
{
await conn.OpenAsync();
@@ -167,6 +204,8 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
await b.Execute(conn, cancellationSource.Token);
}
+
+ // TODO: Close connection after eliminating using statement for above TODO
}
}
@@ -176,13 +215,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
private void OnInfoMessage(object sender, SqlInfoMessageEventArgs args)
{
SqlConnection conn = sender as SqlConnection;
+ if (conn == null)
+ {
+ throw new InvalidOperationException("Sender for OnInfoMessage event must be a SqlConnection");
+ }
foreach(SqlError error in args.Errors)
{
// Did the database context change (error code 5701)?
if (error.Number == DatabaseContextChangeErrorNumber)
{
- ConnectionService.Instance.ChangeConnectionDatabaseContext(EditorConnection.OwnerUri, conn.Database);
+ ConnectionService.Instance.ChangeConnectionDatabaseContext(editorConnection.OwnerUri, conn.Database);
}
}
}
@@ -195,7 +238,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// The starting row of the results
/// How many rows to retrieve
/// A subset of results
- public ResultSetSubset GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount)
+ public Task GetSubset(int batchIndex, int resultSetIndex, int startRow, int rowCount)
{
// Sanity check that the results are available
if (!HasExecuted)
@@ -213,25 +256,10 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
return Batches[batchIndex].GetSubset(resultSetIndex, startRow, rowCount);
}
- ///
- /// Cancels the query by issuing the cancellation token
- ///
- public void Cancel()
- {
- // Make sure that the query hasn't completed execution
- if (HasExecuted)
- {
- throw new InvalidOperationException("The query has already completed, it cannot be cancelled.");
- }
-
- // Issue the cancellation token for the query
- cancellationSource.Cancel();
- }
+ #endregion
#region IDisposable Implementation
- private bool disposed;
-
public void Dispose()
{
Dispose(true);
@@ -248,16 +276,15 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
if (disposing)
{
cancellationSource.Dispose();
+ foreach (Batch b in Batches)
+ {
+ b.Dispose();
+ }
}
disposed = true;
}
- ~Query()
- {
- Dispose(false);
- }
-
#endregion
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
index f23925ab..97d89fc9 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
@@ -10,6 +10,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Hosting;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
@@ -24,6 +25,9 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
private static readonly Lazy instance = new Lazy(() => new QueryExecutionService());
+ ///
+ /// Singleton instance of the query execution service
+ ///
public static QueryExecutionService Instance
{
get { return instance.Value; }
@@ -43,6 +47,22 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
#region Properties
+ ///
+ /// File factory to be used to create a buffer file for results.
+ ///
+ ///
+ /// Made internal here to allow for overriding in unit testing
+ ///
+ internal IFileStreamFactory BufferFileStreamFactory;
+
+ ///
+ /// File factory to be used to create a buffer file for results
+ ///
+ private IFileStreamFactory BufferFileFactory
+ {
+ get { return BufferFileStreamFactory ?? (BufferFileStreamFactory = new ServiceBufferFileStreamFactory()); }
+ }
+
///
/// The collection of active queries
///
@@ -134,7 +154,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
var result = new QueryExecuteSubsetResult
{
Message = null,
- ResultSubset = query.GetSubset(subsetParams.BatchIndex,
+ ResultSubset = await query.GetSubset(subsetParams.BatchIndex,
subsetParams.ResultSetIndex, subsetParams.RowsStartIndex, subsetParams.RowsCount)
};
await requestContext.SendResult(result);
@@ -266,7 +286,7 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
QueryExecutionSettings settings = WorkspaceService.Instance.CurrentSettings.QueryExecutionSettings;
// If we can't add the query now, it's assumed the query is in progress
- Query newQuery = new Query(executeParams.QueryText, connectionInfo, settings);
+ Query newQuery = new Query(executeParams.QueryText, connectionInfo, settings, BufferFileFactory);
if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery))
{
await requestContext.SendResult(new QueryExecuteResult
diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
index 058dc54c..84e18c99 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs
@@ -1,4 +1,4 @@
-//
+//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
@@ -7,46 +7,130 @@ using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
+using Microsoft.SqlTools.ServiceLayer.Utility;
namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
{
- public class ResultSet
+ public class ResultSet : IDisposable
{
- public DbColumn[] Columns { get; set; }
+ #region Constants
- public List Rows { get; private set; }
+ private const int DefaultMaxCharsToStore = 65535; // 64 KB - QE default
- public ResultSet()
- {
- Rows = new List();
- }
+ // xml is a special case so number of chars to store is usually greater than for other long types
+ private const int DefaultMaxXmlCharsToStore = 2097152; // 2 MB - QE default
+
+ #endregion
+
+ #region Member Variables
///
- /// Add a row of data to the result set using a that has already
- /// read in a row.
+ /// For IDisposable pattern, whether or not object has been disposed
///
- /// A that has already had a read performed
- public void AddRow(DbDataReader reader)
+ private bool disposed;
+
+ ///
+ /// The factory to use to get reading/writing handlers
+ ///
+ private readonly IFileStreamFactory fileStreamFactory;
+
+ ///
+ /// File stream reader that will be reused to make rapid-fire retrieval of result subsets
+ /// quick and low perf impact.
+ ///
+ private IFileStreamReader fileStreamReader;
+
+ ///
+ /// Whether or not the result set has been read in from the database
+ ///
+ private bool hasBeenRead;
+
+ ///
+ /// The name of the temporary file we're using to output these results in
+ ///
+ private readonly string outputFileName;
+
+ #endregion
+
+ ///
+ /// Creates a new result set and initializes its state
+ ///
+ /// The reader from executing a query
+ /// Factory for creating a reader/writer
+ public ResultSet(DbDataReader reader, IFileStreamFactory factory)
{
- List row = new List();
- for (int i = 0; i < reader.FieldCount; ++i)
+ // Sanity check to make sure we got a reader
+ if (reader == null)
{
- row.Add(reader.GetValue(i));
+ throw new ArgumentNullException(nameof(reader), "Reader cannot be null");
}
- Rows.Add(row.ToArray());
+ DataReader = new StorageDataReader(reader);
+
+ // Initialize the storage
+ outputFileName = factory.CreateFile();
+ FileOffsets = new LongList();
+
+ // Store the factory
+ fileStreamFactory = factory;
+ hasBeenRead = false;
}
+ #region Properties
+
+ ///
+ /// The columns for this result set
+ ///
+ public DbColumnWrapper[] Columns { get; private set; }
+
+ ///
+ /// The reader to use for this resultset
+ ///
+ private StorageDataReader DataReader { get; set; }
+
+ ///
+ /// A list of offsets into the buffer file that correspond to where rows start
+ ///
+ private LongList FileOffsets { get; set; }
+
+ ///
+ /// Maximum number of characters to store for a field
+ ///
+ public int MaxCharsToStore { get { return DefaultMaxCharsToStore; } }
+
+ ///
+ /// Maximum number of characters to store for an XML field
+ ///
+ public int MaxXmlCharsToStore { get { return DefaultMaxXmlCharsToStore; } }
+
+ ///
+ /// The number of rows for this result set
+ ///
+ public long RowCount { get; private set; }
+
+ #endregion
+
+ #region Public Methods
+
///
/// Generates a subset of the rows from the result set
///
/// The starting row of the results
/// How many rows to retrieve
/// A subset of results
- public ResultSetSubset GetSubset(int startRow, int rowCount)
+ public Task GetSubset(int startRow, int rowCount)
{
+ // Sanity check to make sure that the results have been read beforehand
+ if (!hasBeenRead || fileStreamReader == null)
+ {
+ throw new InvalidOperationException("Cannot read subset unless the results have been read from the server");
+ }
+
// Sanity check to make sure that the row and the row count are within bounds
- if (startRow < 0 || startRow >= Rows.Count)
+ if (startRow < 0 || startRow >= RowCount)
{
throw new ArgumentOutOfRangeException(nameof(startRow), "Start row cannot be less than 0 " +
"or greater than the number of rows in the resultset");
@@ -56,13 +140,79 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
throw new ArgumentOutOfRangeException(nameof(rowCount), "Row count must be a positive integer");
}
- // Retrieve the subset of the results as per the request
- object[][] rows = Rows.Skip(startRow).Take(rowCount).ToArray();
- return new ResultSetSubset
+ return Task.Factory.StartNew(() =>
{
- Rows = rows,
- RowCount = rows.Length
- };
+ // Figure out which rows we need to read back
+ IEnumerable rowOffsets = FileOffsets.Skip(startRow).Take(rowCount);
+
+ // Iterate over the rows we need and process them into output
+ object[][] rows = rowOffsets.Select(rowOffset => fileStreamReader.ReadRow(rowOffset, Columns)).ToArray();
+
+ // Retrieve the subset of the results as per the request
+ return new ResultSetSubset
+ {
+ Rows = rows,
+ RowCount = rows.Length
+ };
+ });
}
+
+ ///
+ /// Reads from the reader until there are no more results to read
+ ///
+ /// Cancellation token for cancelling the query
+ public async Task ReadResultToEnd(CancellationToken cancellationToken)
+ {
+ // Open a writer for the file
+ using (IFileStreamWriter fileWriter = fileStreamFactory.GetWriter(outputFileName, MaxCharsToStore, MaxXmlCharsToStore))
+ {
+ // If we can initialize the columns using the column schema, use that
+ if (!DataReader.DbDataReader.CanGetColumnSchema())
+ {
+ throw new InvalidOperationException("Could not retrieve column schema for result set.");
+ }
+ Columns = DataReader.Columns;
+ long currentFileOffset = 0;
+
+ while (await DataReader.ReadAsync(cancellationToken))
+ {
+ RowCount++;
+ FileOffsets.Add(currentFileOffset);
+ currentFileOffset += fileWriter.WriteRow(DataReader);
+ }
+ }
+
+ // Mark that result has been read
+ hasBeenRead = true;
+ fileStreamReader = fileStreamFactory.GetReader(outputFileName);
+ }
+
+ #endregion
+
+ #region IDisposable Implementation
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposed)
+ {
+ return;
+ }
+
+ if (disposing)
+ {
+ fileStreamReader?.Dispose();
+ fileStreamFactory.DisposeFile(outputFileName);
+ }
+
+ disposed = true;
+ }
+
+ #endregion
}
}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs
new file mode 100644
index 00000000..afacc98f
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/LongList.cs
@@ -0,0 +1,259 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+
+namespace Microsoft.SqlTools.ServiceLayer.Utility
+{
+ ///
+ /// Collection class that permits storage of over int.MaxValue items. This is performed
+ /// by using a 2D list of lists. The internal lists are only initialized as necessary. This
+ /// collection implements IEnumerable to make it easier to run LINQ queries against it.
+ ///
+ ///
+ /// This class is based on code from $\Data Tools\SSMS_Main\sql\ssms\core\DataStorage\ArrayList64.cs
+ /// with additions to bring it up to .NET 4.5 standards
+ ///
+ /// Type of the values to store
+ public class LongList : IEnumerable
+ {
+ #region Member Variables
+
+ private List> expandedList;
+ private readonly List shortList;
+
+ #endregion
+
+ ///
+ /// Creates a new long list
+ ///
+ public LongList()
+ {
+ shortList = new List();
+ Count = 0;
+ }
+
+ #region Properties
+
+ ///
+ /// The total number of elements in the array
+ ///
+ public long Count { get; private set; }
+
+ public T this[long index]
+ {
+ get { return GetItem(index); }
+ }
+
+ #endregion
+
+ #region Public Methods
+
+ ///
+ /// Adds the specified value to the end of the list
+ ///
+ /// Value to add to the list
+ /// Index of the item that was just added
+ public long Add(T val)
+ {
+ if (Count <= int.MaxValue)
+ {
+ shortList.Add(val);
+ }
+ else // need to split values into several arrays
+ {
+ if (expandedList == null)
+ {
+ // very inefficient so delay as much as possible
+ // immediately add 0th array
+ expandedList = new List> {shortList};
+ }
+
+ int arrayIndex = (int)(Count/int.MaxValue); // 0 based
+
+ List arr;
+ if (expandedList.Count <= arrayIndex) // need to make a new array
+ {
+ arr = new List();
+ expandedList.Add(arr);
+ }
+ else // use existing array
+ {
+ arr = expandedList[arrayIndex];
+ }
+ arr.Add(val);
+ }
+ return (++Count);
+ }
+
+ ///
+ /// Returns the item at the specified index
+ ///
+ /// Index of the item to return
+ /// The item at the index specified
+ public T GetItem(long index)
+ {
+ T val = default(T);
+
+ if (Count <= int.MaxValue)
+ {
+ int i32Index = Convert.ToInt32(index);
+ val = shortList[i32Index];
+ }
+ else
+ {
+ int iArray32Index = (int) (Count/int.MaxValue);
+ if (expandedList.Count > iArray32Index)
+ {
+ List arr = expandedList[iArray32Index];
+
+ int i32Index = (int) (Count%int.MaxValue);
+ if (arr.Count > i32Index)
+ {
+ val = arr[i32Index];
+ }
+ }
+ }
+ return val;
+ }
+
+ ///
+ /// Removes an item at the specified location and shifts all the items after the provided
+ /// index up by one.
+ ///
+ /// The index to remove from the list
+ public void RemoveAt(long index)
+ {
+ if (Count <= int.MaxValue)
+ {
+ int iArray32MemberIndex = Convert.ToInt32(index); // 0 based
+ shortList.RemoveAt(iArray32MemberIndex);
+ }
+ else // handle the case of multiple arrays
+ {
+ // find out which array it is in
+ int arrayIndex = (int) (index/int.MaxValue);
+ List arr = expandedList[arrayIndex];
+
+ // find out index into this array
+ int iArray32MemberIndex = (int) (index%int.MaxValue);
+ arr.RemoveAt(iArray32MemberIndex);
+
+ // now shift members of the array back one
+ int iArray32TotalIndex = (int) (Count/Int32.MaxValue);
+ for (int i = arrayIndex + 1; i < iArray32TotalIndex; i++)
+ {
+ List arr1 = expandedList[i - 1];
+ List arr2 = expandedList[i];
+
+ arr1.Add(arr2[int.MaxValue - 1]);
+ arr2.RemoveAt(0);
+ }
+ }
+ --Count;
+ }
+
+ #endregion
+
+ #region IEnumerable Implementation
+
+ ///
+ /// Returns a generic enumerator for enumeration of this LongList
+ ///
+ /// Enumerator for LongList
+ public IEnumerator GetEnumerator()
+ {
+ return new LongListEnumerator(this);
+ }
+
+ ///
+ /// Returns an enumerator for enumeration of this LongList
+ ///
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return GetEnumerator();
+ }
+
+ #endregion
+
+ public class LongListEnumerator : IEnumerator
+ {
+ #region Member Variables
+
+ ///
+ /// The index into the list of the item that is the current item
+ ///
+ private long index;
+
+ ///
+ /// The current list that we're iterating over.
+ ///
+ private readonly LongList localList;
+
+ #endregion
+
+ ///
+ /// Constructs a new enumerator for a given LongList
+ ///
+ /// The list to enumerate
+ public LongListEnumerator(LongList list)
+ {
+ localList = list;
+ index = 0;
+ Current = default(TEt);
+ }
+
+ #region IEnumerator Implementation
+
+ ///
+ /// Returns the current item in the enumeration
+ ///
+ public TEt Current { get; private set; }
+
+ object IEnumerator.Current
+ {
+ get { return Current; }
+ }
+
+ ///
+ /// Moves to the next item in the list we're iterating over
+ ///
+ /// Whether or not the move was successful
+ public bool MoveNext()
+ {
+ if (index < localList.Count)
+ {
+ Current = localList[index];
+ index++;
+ return true;
+ }
+ Current = default(TEt);
+ return false;
+ }
+
+ ///
+ /// Resets the enumeration
+ ///
+ public void Reset()
+ {
+ index = 0;
+ Current = default(TEt);
+ }
+
+ ///
+ /// Disposal method. Does nothing.
+ ///
+ public void Dispose()
+ {
+ }
+
+ #endregion
+ }
+ }
+}
+
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs
index 2437dc30..c7a00c09 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs
@@ -3,9 +3,11 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
+using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
+using System.IO;
using System.Data.SqlClient;
using System.Threading;
using Microsoft.SqlTools.ServiceLayer.Connection;
@@ -16,6 +18,7 @@ using Microsoft.SqlServer.Management.SqlParser.Binder;
using Microsoft.SqlServer.Management.SqlParser.MetadataProvider;
using Microsoft.SqlTools.ServiceLayer.LanguageServices;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
@@ -71,7 +74,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
public static Batch GetBasicExecutedBatch()
{
- Batch batch = new Batch(StandardQuery, 1);
+ Batch batch = new Batch(StandardQuery, 1, GetFileStreamFactory());
batch.Execute(CreateTestConnection(new[] {StandardTestData}, false), CancellationToken.None).Wait();
return batch;
}
@@ -79,11 +82,78 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
public static Query GetBasicExecutedQuery()
{
ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false);
- Query query = new Query(StandardQuery, ci, new QueryExecutionSettings());
+ Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(), GetFileStreamFactory());
query.Execute().Wait();
return query;
}
+ #region FileStreamWriteMocking
+
+ public static IFileStreamFactory GetFileStreamFactory()
+ {
+ Mock mock = new Mock();
+ mock.Setup(fsf => fsf.GetReader(It.IsAny()))
+ .Returns(new ServiceBufferFileStreamReader(new InMemoryWrapper(), It.IsAny()));
+ mock.Setup(fsf => fsf.GetWriter(It.IsAny(), It.IsAny(), It.IsAny()))
+ .Returns(new ServiceBufferFileStreamWriter(new InMemoryWrapper(), It.IsAny(), 1024,
+ 1024));
+
+ return mock.Object;
+ }
+
+ public class InMemoryWrapper : IFileStreamWrapper
+ {
+ private readonly byte[] storage = new byte[8192];
+ private readonly MemoryStream memoryStream;
+ private bool readingOnly;
+
+ public InMemoryWrapper()
+ {
+ memoryStream = new MemoryStream(storage);
+ }
+
+ public void Dispose()
+ {
+ // We'll dispose this via a special method
+ }
+
+ public void Init(string fileName, int bufferSize, FileAccess fAccess)
+ {
+ readingOnly = fAccess == FileAccess.Read;
+ }
+
+ public int ReadData(byte[] buffer, int bytes)
+ {
+ return ReadData(buffer, bytes, memoryStream.Position);
+ }
+
+ public int ReadData(byte[] buffer, int bytes, long fileOffset)
+ {
+ memoryStream.Seek(fileOffset, SeekOrigin.Begin);
+ return memoryStream.Read(buffer, 0, bytes);
+ }
+
+ public int WriteData(byte[] buffer, int bytes)
+ {
+ if (readingOnly) { throw new InvalidOperationException(); }
+ memoryStream.Write(buffer, 0, bytes);
+ memoryStream.Flush();
+ return bytes;
+ }
+
+ public void Flush()
+ {
+ if (readingOnly) { throw new InvalidOperationException(); }
+ }
+
+ public void Close()
+ {
+ memoryStream.Dispose();
+ }
+ }
+
+ #endregion
+
#region DbConnection Mocking
public static DbCommand CreateTestCommand(Dictionary[][] data, bool throwOnRead)
@@ -151,12 +221,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
out ConnectionInfo connInfo
)
{
- textDocument = new TextDocumentPosition();
- textDocument.TextDocument = new TextDocumentIdentifier();
- textDocument.TextDocument.Uri = Common.OwnerUri;
- textDocument.Position = new Position();
- textDocument.Position.Line = 0;
- textDocument.Position.Character = 0;
+ textDocument = new TextDocumentPosition
+ {
+ TextDocument = new TextDocumentIdentifier {Uri = OwnerUri},
+ Position = new Position
+ {
+ Line = 0,
+ Character = 0
+ }
+ };
connInfo = Common.CreateTestConnectionInfo(null, false);
@@ -166,15 +239,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
var binder = BinderProvider.CreateBinder(metadataProvider);
LanguageService.Instance.ScriptParseInfoMap.Add(textDocument.TextDocument.Uri,
- new ScriptParseInfo()
+ new ScriptParseInfo
{
Binder = binder,
MetadataProvider = metadataProvider,
MetadataDisplayInfoProvider = displayInfoProvider
});
- scriptFile = new ScriptFile();
- scriptFile.ClientFilePath = textDocument.TextDocument.Uri;
+ scriptFile = new ScriptFile {ClientFilePath = textDocument.TextDocument.Uri};
+
}
public static ServerConnection GetServerConnection(ConnectionInfo connection)
@@ -206,7 +279,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
OwnerUri = OwnerUri
});
}
- return new QueryExecutionService(connectionService);
+ return new QueryExecutionService(connectionService) {BufferFileStreamFactory = GetFileStreamFactory()};
}
#endregion
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs
new file mode 100644
index 00000000..f1a4cda0
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/FileStreamWrapperTests.cs
@@ -0,0 +1,221 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
+using Xunit;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage
+{
+ public class FileStreamWrapperTests
+ {
+ [Theory]
+ [InlineData(null)]
+ [InlineData("")]
+ [InlineData(" ")]
+ public void InitInvalidFilenameParameter(string fileName)
+ {
+ // If:
+ // ... I have a file stream wrapper that is initialized with invalid fileName
+ // Then:
+ // ... It should throw an argument null exception
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ Assert.Throws(() => fsw.Init(fileName, 8192, FileAccess.Read));
+ }
+ }
+
+ [Theory]
+ [InlineData(0)]
+ [InlineData(-1)]
+ public void InitInvalidBufferLength(int bufferLength)
+ {
+ // If:
+ // ... I have a file stream wrapper that is initialized with an invalid buffer length
+ // Then:
+ // ... I should throw an argument out of range exception
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ Assert.Throws(() => fsw.Init("validFileName", bufferLength, FileAccess.Read));
+ }
+ }
+
+ [Fact]
+ public void InitInvalidFileAccessMode()
+ {
+ // If:
+ // ... I attempt to open a file stream wrapper that is initialized with an invalid file
+ // access mode
+ // Then:
+ // ... I should get an invalid argument exception
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ Assert.Throws(() => fsw.Init("validFileName", 8192, FileAccess.Write));
+ }
+ }
+
+ [Fact]
+ public void InitSuccessful()
+ {
+ string fileName = Path.GetTempFileName();
+
+ try
+ {
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ // If:
+ // ... I have a file stream wrapper that is initialized with valid parameters
+ fsw.Init(fileName, 8192, FileAccess.ReadWrite);
+
+ // Then:
+ // ... The file should exist
+ FileInfo fileInfo = new FileInfo(fileName);
+ Assert.True(fileInfo.Exists);
+
+ // ... The file should be marked as hidden
+ Assert.True((fileInfo.Attributes & FileAttributes.Hidden) != 0);
+ }
+ }
+ finally
+ {
+ // Cleanup:
+ // ... Delete the file that was created
+ try { File.Delete(fileName); } catch { /* Don't care */ }
+ }
+ }
+
+ [Fact]
+ public void PerformOpWithoutInit()
+ {
+ byte[] buf = new byte[10];
+
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ // If:
+ // ... I have a file stream wrapper that hasn't been initialized
+ // Then:
+ // ... Attempting to perform any operation will result in an exception
+ Assert.Throws(() => fsw.ReadData(buf, 1));
+ Assert.Throws(() => fsw.ReadData(buf, 1, 0));
+ Assert.Throws(() => fsw.WriteData(buf, 1));
+ Assert.Throws(() => fsw.Flush());
+ }
+ }
+
+ [Fact]
+ public void PerformWriteOpOnReadOnlyWrapper()
+ {
+ byte[] buf = new byte[10];
+
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ // If:
+ // ... I have a readonly file stream wrapper
+ // Then:
+ // ... Attempting to perform any write operation should result in an exception
+ Assert.Throws(() => fsw.WriteData(buf, 1));
+ Assert.Throws(() => fsw.Flush());
+ }
+ }
+
+ [Theory]
+ [InlineData(1024, 20, 10)] // Standard scenario
+ [InlineData(1024, 100, 100)] // Requested more bytes than there are
+ [InlineData(5, 20, 10)] // Internal buffer too small, force a move-to operation
+ public void ReadData(int internalBufferLength, int outBufferLength, int requestedBytes)
+ {
+ // Setup:
+ // ... I have a file that has a handful of bytes in it
+ string fileName = Path.GetTempFileName();
+ const string stringToWrite = "hello";
+ CreateTestFile(fileName, stringToWrite);
+ byte[] targetBytes = Encoding.Unicode.GetBytes(stringToWrite);
+
+ try
+ {
+ // If:
+ // ... I have a file stream wrapper that has been initialized to an existing file
+ // ... And I read some bytes from it
+ int bytesRead;
+ byte[] buf = new byte[outBufferLength];
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ fsw.Init(fileName, internalBufferLength, FileAccess.Read);
+ bytesRead = fsw.ReadData(buf, targetBytes.Length);
+ }
+
+ // Then:
+ // ... I should get those bytes back
+ Assert.Equal(targetBytes.Length, bytesRead);
+ Assert.True(targetBytes.Take(targetBytes.Length).SequenceEqual(buf.Take(targetBytes.Length)));
+
+ }
+ finally
+ {
+ // Cleanup:
+ // ... Delete the test file
+ CleanupTestFile(fileName);
+ }
+ }
+
+ [Theory]
+ [InlineData(1024)] // Standard scenario
+ [InlineData(10)] // Internal buffer too small, forces a flush
+ public void WriteData(int internalBufferLength)
+ {
+ string fileName = Path.GetTempFileName();
+ byte[] bytesToWrite = Encoding.Unicode.GetBytes("hello");
+
+ try
+ {
+ // If:
+ // ... I have a file stream that has been initialized
+ // ... And I write some bytes to it
+ using (FileStreamWrapper fsw = new FileStreamWrapper())
+ {
+ fsw.Init(fileName, internalBufferLength, FileAccess.ReadWrite);
+ int bytesWritten = fsw.WriteData(bytesToWrite, bytesToWrite.Length);
+
+ Assert.Equal(bytesToWrite.Length, bytesWritten);
+ }
+
+ // Then:
+ // ... The file I wrote to should contain only the bytes I wrote out
+ using (FileStream fs = File.OpenRead(fileName))
+ {
+ byte[] readBackBytes = new byte[1024];
+ int bytesRead = fs.Read(readBackBytes, 0, readBackBytes.Length);
+
+ Assert.Equal(bytesToWrite.Length, bytesRead); // If bytes read is not equal, then more or less of the original string was written to the file
+ Assert.True(bytesToWrite.SequenceEqual(readBackBytes.Take(bytesRead)));
+ }
+ }
+ finally
+ {
+ // Cleanup:
+ // ... Delete the test file
+ CleanupTestFile(fileName);
+ }
+ }
+
+ private static void CreateTestFile(string fileName, string value)
+ {
+ using (FileStream fs = new FileStream(fileName, FileMode.OpenOrCreate, FileAccess.ReadWrite))
+ {
+ byte[] bytesToWrite = Encoding.Unicode.GetBytes(value);
+ fs.Write(bytesToWrite, 0, bytesToWrite.Length);
+ fs.Flush();
+ }
+ }
+
+ private static void CleanupTestFile(string fileName)
+ {
+ try { File.Delete(fileName); } catch { /* Don't Care */}
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs
new file mode 100644
index 00000000..b10a7f92
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DataStorage/ServiceBufferFileStreamReaderWriterTests.cs
@@ -0,0 +1,295 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+using System.Collections.Generic;
+using System.Data.SqlTypes;
+using System.Text;
+using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
+using Xunit;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution.DataStorage
+{
+ public class ReaderWriterPairTest
+ {
+ private static void VerifyReadWrite(int valueLength, T value, Func writeFunc, Func> readFunc)
+ {
+ // Setup: Create a mock file stream wrapper
+ Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper();
+ try
+ {
+ // If:
+ // ... I write a type T to the writer
+ using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10))
+ {
+ int writtenBytes = writeFunc(writer, value);
+ Assert.Equal(valueLength, writtenBytes);
+ }
+
+ // ... And read the type T back
+ FileStreamReadResult outValue;
+ using (ServiceBufferFileStreamReader reader = new ServiceBufferFileStreamReader(mockWrapper, "abc"))
+ {
+ outValue = readFunc(reader);
+ }
+
+ // Then:
+ Assert.Equal(value, outValue.Value);
+ Assert.Equal(valueLength, outValue.TotalLength);
+ Assert.False(outValue.IsNull);
+ }
+ finally
+ {
+ // Cleanup: Close the wrapper
+ mockWrapper.Close();
+ }
+ }
+
+ [Theory]
+ [InlineData(0)]
+ [InlineData(10)]
+ [InlineData(-10)]
+ [InlineData(short.MaxValue)] // Two byte number
+ [InlineData(short.MinValue)] // Negative two byte number
+ public void Int16(short value)
+ {
+ VerifyReadWrite(sizeof(short) + 1, value, (writer, val) => writer.WriteInt16(val), reader => reader.ReadInt16(0));
+ }
+
+ [Theory]
+ [InlineData(0)]
+ [InlineData(10)]
+ [InlineData(-10)]
+ [InlineData(short.MaxValue)] // Two byte number
+ [InlineData(short.MinValue)] // Negative two byte number
+ [InlineData(int.MaxValue)] // Four byte number
+ [InlineData(int.MinValue)] // Negative four byte number
+ public void Int32(int value)
+ {
+ VerifyReadWrite(sizeof(int) + 1, value, (writer, val) => writer.WriteInt32(val), reader => reader.ReadInt32(0));
+ }
+
+ [Theory]
+ [InlineData(0)]
+ [InlineData(10)]
+ [InlineData(-10)]
+ [InlineData(short.MaxValue)] // Two byte number
+ [InlineData(short.MinValue)] // Negative two byte number
+ [InlineData(int.MaxValue)] // Four byte number
+ [InlineData(int.MinValue)] // Negative four byte number
+ [InlineData(long.MaxValue)] // Eight byte number
+ [InlineData(long.MinValue)] // Negative eight byte number
+ public void Int64(long value)
+ {
+ VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteInt64(val), reader => reader.ReadInt64(0));
+ }
+
+ [Theory]
+ [InlineData(0)]
+ [InlineData(10)]
+ public void Byte(byte value)
+ {
+ VerifyReadWrite(sizeof(byte) + 1, value, (writer, val) => writer.WriteByte(val), reader => reader.ReadByte(0));
+ }
+
+ [Theory]
+ [InlineData('a')]
+ [InlineData('1')]
+ [InlineData((char)0x9152)] // Test something in the UTF-16 space
+ public void Char(char value)
+ {
+ VerifyReadWrite(sizeof(char) + 1, value, (writer, val) => writer.WriteChar(val), reader => reader.ReadChar(0));
+ }
+
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public void Boolean(bool value)
+ {
+ VerifyReadWrite(sizeof(bool) + 1, value, (writer, val) => writer.WriteBoolean(val), reader => reader.ReadBoolean(0));
+ }
+
+ [Theory]
+ [InlineData(0)]
+ [InlineData(10.1)]
+ [InlineData(-10.1)]
+ [InlineData(float.MinValue)]
+ [InlineData(float.MaxValue)]
+ [InlineData(float.PositiveInfinity)]
+ [InlineData(float.NegativeInfinity)]
+ public void Single(float value)
+ {
+ VerifyReadWrite(sizeof(float) + 1, value, (writer, val) => writer.WriteSingle(val), reader => reader.ReadSingle(0));
+ }
+
+ [Theory]
+ [InlineData(0)]
+ [InlineData(10.1)]
+ [InlineData(-10.1)]
+ [InlineData(float.MinValue)]
+ [InlineData(float.MaxValue)]
+ [InlineData(float.PositiveInfinity)]
+ [InlineData(float.NegativeInfinity)]
+ [InlineData(double.PositiveInfinity)]
+ [InlineData(double.NegativeInfinity)]
+ [InlineData(double.MinValue)]
+ [InlineData(double.MaxValue)]
+ public void Double(double value)
+ {
+ VerifyReadWrite(sizeof(double) + 1, value, (writer, val) => writer.WriteDouble(val), reader => reader.ReadDouble(0));
+ }
+
+ [Fact]
+ public void SqlDecimalTest()
+ {
+ // Setup: Create some test values
+ // NOTE: We are doing these here instead of InlineData because SqlDecimal values can't be written as constant expressions
+ SqlDecimal[] testValues =
+ {
+ SqlDecimal.MaxValue, SqlDecimal.MinValue, new SqlDecimal(0x01, 0x01, true, 0, 0, 0, 0)
+ };
+ foreach (SqlDecimal value in testValues)
+ {
+ int valueLength = 4 + value.BinData.Length;
+ VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteSqlDecimal(val), reader => reader.ReadSqlDecimal(0));
+ }
+ }
+
+ [Fact]
+ public void Decimal()
+ {
+ // Setup: Create some test values
+ // NOTE: We are doing these here instead of InlineData because Decimal values can't be written as constant expressions
+ decimal[] testValues =
+ {
+ decimal.Zero, decimal.One, decimal.MinusOne, decimal.MinValue, decimal.MaxValue
+ };
+
+ foreach (decimal value in testValues)
+ {
+ int valueLength = decimal.GetBits(value).Length*4 + 1;
+ VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteDecimal(val), reader => reader.ReadDecimal(0));
+ }
+ }
+
+ [Fact]
+ public void DateTimeTest()
+ {
+ // Setup: Create some test values
+ // NOTE: We are doing these here instead of InlineData because DateTime values can't be written as constant expressions
+ DateTime[] testValues =
+ {
+ DateTime.Now, DateTime.UtcNow, DateTime.MinValue, DateTime.MaxValue
+ };
+ foreach (DateTime value in testValues)
+ {
+ VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteDateTime(val), reader => reader.ReadDateTime(0));
+ }
+ }
+
+ [Fact]
+ public void DateTimeOffsetTest()
+ {
+ // Setup: Create some test values
+ // NOTE: We are doing these here instead of InlineData because DateTimeOffset values can't be written as constant expressions
+ DateTimeOffset[] testValues =
+ {
+ DateTimeOffset.Now, DateTimeOffset.UtcNow, DateTimeOffset.MinValue, DateTimeOffset.MaxValue
+ };
+ foreach (DateTimeOffset value in testValues)
+ {
+ VerifyReadWrite((sizeof(long) + 1)*2, value, (writer, val) => writer.WriteDateTimeOffset(val), reader => reader.ReadDateTimeOffset(0));
+ }
+ }
+
+ [Fact]
+ public void TimeSpanTest()
+ {
+ // Setup: Create some test values
+ // NOTE: We are doing these here instead of InlineData because TimeSpan values can't be written as constant expressions
+ TimeSpan[] testValues =
+ {
+ TimeSpan.Zero, TimeSpan.MinValue, TimeSpan.MaxValue, TimeSpan.FromMinutes(60)
+ };
+ foreach (TimeSpan value in testValues)
+ {
+ VerifyReadWrite(sizeof(long) + 1, value, (writer, val) => writer.WriteTimeSpan(val), reader => reader.ReadTimeSpan(0));
+ }
+ }
+
+ [Fact]
+ public void StringNullTest()
+ {
+ // Setup: Create a mock file stream wrapper
+ Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper();
+
+ // If:
+ // ... I write null as a string to the writer
+ using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10))
+ {
+ // Then:
+ // ... I should get an argument null exception
+ Assert.Throws(() => writer.WriteString(null));
+ }
+ }
+
+ [Theory]
+ [InlineData(0, null)] // Test of empty string
+ [InlineData(1, new[] { 'j' })]
+ [InlineData(1, new[] { (char)0x9152 })]
+ [InlineData(100, new[] { 'j', (char)0x9152 })] // Test alternating utf-16/ascii characters
+ [InlineData(512, new[] { 'j', (char)0x9152 })] // Test that requires a 4 byte length
+ public void StringTest(int length, char[] values)
+ {
+ // Setup:
+ // ... Generate the test value
+ StringBuilder sb = new StringBuilder();
+ for (int i = 0; i < length; i++)
+ {
+ sb.Append(values[i%values.Length]);
+ }
+ string value = sb.ToString();
+ int lengthLength = length == 0 || length > 255 ? 5 : 1;
+ VerifyReadWrite(sizeof(char)*length + lengthLength, value, (writer, val) => writer.WriteString(value), reader => reader.ReadString(0));
+ }
+
+ [Fact]
+ public void BytesNullTest()
+ {
+ // Setup: Create a mock file stream wrapper
+ Common.InMemoryWrapper mockWrapper = new Common.InMemoryWrapper();
+
+ // If:
+ // ... I write null as a string to the writer
+ using (ServiceBufferFileStreamWriter writer = new ServiceBufferFileStreamWriter(mockWrapper, "abc", 10, 10))
+ {
+ // Then:
+ // ... I should get an argument null exception
+ Assert.Throws(() => writer.WriteBytes(null, 0));
+ }
+ }
+
+ [Theory]
+ [InlineData(0, new byte[] { 0x00 })] // Test of empty byte[]
+ [InlineData(1, new byte[] { 0x00 })]
+ [InlineData(1, new byte[] { 0xFF })]
+ [InlineData(100, new byte[] { 0x10, 0xFF, 0x00 })]
+ [InlineData(512, new byte[] { 0x10, 0xFF, 0x00 })] // Test that requires a 4 byte length
+ public void Bytes(int length, byte[] values)
+ {
+ // Setup:
+ // ... Generate the test value
+ List sb = new List();
+ for (int i = 0; i < length; i++)
+ {
+ sb.Add(values[i % values.Length]);
+ }
+ byte[] value = sb.ToArray();
+ int lengthLength = length == 0 || length > 255 ? 5 : 1;
+ int valueLength = sizeof(byte)*length + lengthLength;
+ VerifyReadWrite(valueLength, value, (writer, val) => writer.WriteBytes(value, length), reader => reader.ReadBytes(0));
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
index 49b6c76c..bda6ca0d 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs
@@ -29,7 +29,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
public void BatchCreationTest()
{
// If I create a new batch...
- Batch batch = new Batch(Common.StandardQuery, 1);
+ Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
// Then:
// ... The text of the batch should be stored
@@ -52,7 +52,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
public void BatchExecuteNoResultSets()
{
// If I execute a query that should get no result sets
- Batch batch = new Batch(Common.StandardQuery, 1);
+ Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
batch.Execute(GetConnection(Common.CreateTestConnectionInfo(null, false)), CancellationToken.None).Wait();
// Then:
@@ -79,7 +79,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false);
// If I execute a query that should get one result set
- Batch batch = new Batch(Common.StandardQuery, 1);
+ Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
// Then:
@@ -92,11 +92,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
Assert.Equal(resultSets, batch.ResultSummaries.Length);
// ... Inside the result set should be with 5 rows
- Assert.Equal(Common.StandardRows, batch.ResultSets.First().Rows.Count);
+ Assert.Equal(Common.StandardRows, batch.ResultSets.First().RowCount);
Assert.Equal(Common.StandardRows, batch.ResultSummaries[0].RowCount);
- // ... Inside the result set should have 5 columns and 5 column definitions
- Assert.Equal(Common.StandardColumns, batch.ResultSets.First().Rows[0].Length);
+ // ... Inside the result set should have 5 columns
Assert.Equal(Common.StandardColumns, batch.ResultSets.First().Columns.Length);
Assert.Equal(Common.StandardColumns, batch.ResultSummaries[0].ColumnInfo.Length);
@@ -112,7 +111,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
ConnectionInfo ci = Common.CreateTestConnectionInfo(dataset, false);
// If I execute a query that should get two result sets
- Batch batch = new Batch(Common.StandardQuery, 1);
+ Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
// Then:
@@ -126,10 +125,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
foreach (ResultSet rs in batch.ResultSets)
{
// ... Each result set should have 5 rows
- Assert.Equal(Common.StandardRows, rs.Rows.Count);
+ Assert.Equal(Common.StandardRows, rs.RowCount);
- // ... Inside each result set should be 5 columns and 5 column definitions
- Assert.Equal(Common.StandardColumns, rs.Rows[0].Length);
+ // ... Inside each result set should be 5 columns
Assert.Equal(Common.StandardColumns, rs.Columns.Length);
}
@@ -155,7 +153,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true);
// If I execute a batch that is invalid
- Batch batch = new Batch(Common.StandardQuery, 1);
+ Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
// Then:
@@ -177,7 +175,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] { Common.StandardTestData }, false);
// If I execute a batch
- Batch batch = new Batch(Common.StandardQuery, 1);
+ Batch batch = new Batch(Common.StandardQuery, 1, Common.GetFileStreamFactory());
batch.Execute(GetConnection(ci), CancellationToken.None).Wait();
// Then:
@@ -207,7 +205,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... I create a batch that has an empty query
// Then:
// ... It should throw an exception
- Assert.Throws(() => new Batch(query, 1));
+ Assert.Throws(() => new Batch(query, 1, Common.GetFileStreamFactory()));
+ }
+
+ [Fact]
+ public void BatchNoBufferFactory()
+ {
+ // If:
+ // ... I create a batch that has no file stream factory
+ // Then:
+ // ... It should throw an exception
+ Assert.Throws(() => new Batch("stuff", 1, null));
}
#endregion
@@ -222,7 +230,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// Then:
// ... It should throw an exception
Assert.Throws(() =>
- new Query(null, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings()));
+ new Query(null, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory()));
}
[Fact]
@@ -232,7 +240,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... I create a query that has a null connection info
// Then:
// ... It should throw an exception
- Assert.Throws(() => new Query("Some Query", null, new QueryExecutionSettings()));
+ Assert.Throws(() => new Query("Some Query", null, new QueryExecutionSettings(), Common.GetFileStreamFactory()));
}
[Fact]
@@ -243,7 +251,18 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// Then:
// ... It should throw an exception
Assert.Throws(() =>
- new Query("Some query", Common.CreateTestConnectionInfo(null, false), null));
+ new Query("Some query", Common.CreateTestConnectionInfo(null, false), null, Common.GetFileStreamFactory()));
+ }
+
+ [Fact]
+ public void QueryExecuteNoBufferFactory()
+ {
+ // If:
+ // ... I create a query that has a null file stream factory
+ // Then:
+ // ... It should throw an exception
+ Assert.Throws(() =>
+ new Query("Some query", Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(),null));
}
[Fact]
@@ -252,7 +271,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// If:
// ... I create a query from a single batch (without separator)
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
- Query query = new Query(Common.StandardQuery, ci, new QueryExecutionSettings());
+ Query query = new Query(Common.StandardQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
// Then:
// ... I should get a single batch to execute that hasn't been executed
@@ -279,7 +298,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// If:
// ... I create a query from a single batch that does nothing
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
- Query query = new Query(Common.NoOpQuery, ci, new QueryExecutionSettings());
+ Query query = new Query(Common.NoOpQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
// Then:
// ... I should get no batches back
@@ -305,7 +324,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... I create a query from two batches (with separator)
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
string queryText = string.Format("{0}\r\nGO\r\n{0}", Common.StandardQuery);
- Query query = new Query(queryText, ci, new QueryExecutionSettings());
+ Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
// Then:
// ... I should get back two batches to execute that haven't been executed
@@ -333,7 +352,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... I create a query from a two batches (with separator)
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, false);
string queryText = string.Format("{0}\r\nGO\r\n{1}", Common.StandardQuery, Common.NoOpQuery);
- Query query = new Query(queryText, ci, new QueryExecutionSettings());
+ Query query = new Query(queryText, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
// Then:
// ... I should get back one batch to execute that hasn't been executed
@@ -359,7 +378,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// If:
// ... I create a query from an invalid batch
ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true);
- Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings());
+ Query query = new Query(Common.InvalidQuery, ci, new QueryExecutionSettings(), Common.GetFileStreamFactory());
// Then:
// ... I should get back a query with one batch not executed
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs
index a6f5e9fe..2968e709 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs
@@ -28,7 +28,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
Batch b = Common.GetBasicExecutedBatch();
// ... And I ask for a subset with valid arguments
- ResultSetSubset subset = b.GetSubset(0, 0, rowCount);
+ ResultSetSubset subset = b.GetSubset(0, 0, rowCount).Result;
// Then:
// I should get the requested number of rows
@@ -51,7 +51,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... And I ask for a subset with an invalid result set index
// Then:
// ... It should throw an exception
- Assert.Throws(() => b.GetSubset(resultSetIndex, rowStartInex, rowCount));
+ Assert.ThrowsAsync(() => b.GetSubset(resultSetIndex, rowStartInex, rowCount)).Wait();
}
#endregion
@@ -62,12 +62,12 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
public void SubsetUnexecutedQueryTest()
{
// If I have a query that has *not* been executed
- Query q = new Query(Common.StandardQuery, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings());
+ Query q = new Query(Common.StandardQuery, Common.CreateTestConnectionInfo(null, false), new QueryExecutionSettings(), Common.GetFileStreamFactory());
// ... And I ask for a subset with valid arguments
// Then:
// ... It should throw an exception
- Assert.Throws(() => q.GetSubset(0, 0, 0, 2));
+ Assert.ThrowsAsync(() => q.GetSubset(0, 0, 0, 2)).Wait();
}
[Theory]
@@ -81,7 +81,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
// ... And I ask for a subset with an invalid result set index
// Then:
// ... It should throw an exception
- Assert.Throws(() => q.GetSubset(batchIndex, 0, 0, 1));
+ Assert.ThrowsAsync(() => q.GetSubset(batchIndex, 0, 0, 1)).Wait();
}
#endregion
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs
new file mode 100644
index 00000000..c2765783
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbColumn.cs
@@ -0,0 +1,21 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System.Data.Common;
+
+namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
+{
+ public class TestDbColumn : DbColumn
+ {
+ public TestDbColumn()
+ {
+ base.IsLong = false;
+ base.ColumnName = "Test Column";
+ base.ColumnSize = 128;
+ base.AllowDBNull = true;
+ base.DataType = typeof(string);
+ base.DataTypeName = "nvarchar";
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs
index e2003789..0330cda0 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs
@@ -64,6 +64,15 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
return this[ordinal];
}
+ public override int GetValues(object[] values)
+ {
+ for(int i = 0; i < Rows.Current.Count; i++)
+ {
+ values[i] = this[i];
+ }
+ return Rows.Current.Count;
+ }
+
public override object this[string name]
{
get { return Rows.Current[name]; }
@@ -84,11 +93,16 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
List columns = new List();
for (int i = 0; i < ResultSet.Current[0].Count; i++)
{
- columns.Add(new Mock().Object);
+ columns.Add(new TestDbColumn());
}
return new ReadOnlyCollection(columns);
}
+ public override bool IsDBNull(int ordinal)
+ {
+ return this[ordinal] == null;
+ }
+
public override int FieldCount { get { return Rows?.Current.Count ?? 0; } }
public override int RecordsAffected
@@ -189,16 +203,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.Utility
throw new NotImplementedException();
}
- public override int GetValues(object[] values)
- {
- throw new NotImplementedException();
- }
-
- public override bool IsDBNull(int ordinal)
- {
- throw new NotImplementedException();
- }
-
public override IEnumerator GetEnumerator()
{
throw new NotImplementedException();