//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Kusto.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.Utility;

namespace Microsoft.Kusto.ServiceLayer.QueryExecution.DataStorage
{
    /// <summary>
    /// Wrapper around an <see cref="IDataReader"/> to perform some special operations more simply
    /// </summary>
    public class StorageDataReader
    {
        /// <summary>
        /// Constructs a new wrapper around the provided reader
        /// </summary>
        /// <param name="reader">The reader to wrap around</param>
        public StorageDataReader(IDataReader reader)
        {
            // Sanity check to make sure there is a data reader
            Validate.IsNotNull(nameof(reader), reader);

            DataReader = reader;

            // Read the columns into a set of wrappers. GetSchemaTable() is allowed to return
            // null when the reader has no result set; treat that as "no columns" instead of
            // failing with a NullReferenceException.
            List<DbColumnWrapper> columnList = new List<DbColumnWrapper>();
            DataTable schemaTable = DataReader.GetSchemaTable();
            if (schemaTable != null)
            {
                foreach (DataRow row in schemaTable.Rows)
                {
                    columnList.Add(new DbColumnWrapper(row));
                }
            }

            Columns = columnList.ToArray();
            HasLongColumns = Columns.Any(column => column.IsLong.HasValue && column.IsLong.Value);
        }

        #region Properties

        /// <summary>
        /// All the columns that this reader currently contains
        /// </summary>
        public DbColumnWrapper[] Columns { get; private set; }

        /// <summary>
        /// The <see cref="IDataReader"/> that will be read from
        /// </summary>
        public IDataReader DataReader { get; private set; }

        /// <summary>
        /// Whether or not any of the columns of this reader are 'long', such as nvarchar(max)
        /// </summary>
        public bool HasLongColumns { get; private set; }

        #endregion

        #region DbDataReader Methods

        /// <summary>
        /// Runs <see cref="IDataReader.Read"/> on a thread pool thread.
        /// </summary>
        /// <param name="cancellationToken">
        /// The cancellation token to use for cancelling a query. If it is already cancelled
        /// before the read is scheduled, the returned task is cancelled and the reader is not
        /// advanced. A read that has already started is not interrupted.
        /// </param>
        /// <returns>A task whose result is the value returned by <see cref="IDataReader.Read"/></returns>
        public Task<bool> ReadAsync(CancellationToken cancellationToken)
        {
            return Task.Run(() => DataReader.Read(), cancellationToken);
        }

        /// <summary>
        /// Retrieves a value
        /// </summary>
        /// <param name="i">Column ordinal</param>
        /// <returns>The value of the given column</returns>
        public object GetValue(int i)
        {
            return DataReader.GetValue(i);
        }

        /// <summary>
        /// Stores all values of the current row into the provided object array
        /// </summary>
        /// <param name="values">Where to store the values from this row</param>
        public void GetValues(object[] values)
        {
            DataReader.GetValues(values);
        }

        /// <summary>
        /// Whether or not the cell of the given column at the current row is a DBNull
        /// </summary>
        /// <param name="i">Column ordinal</param>
        /// <returns>True if the cell is DBNull, false otherwise</returns>
        public bool IsDBNull(int i)
        {
            return DataReader.IsDBNull(i);
        }

        #endregion

        #region Public Methods

        /// <summary>
        /// Retrieves bytes with a maximum number of bytes to return
        /// </summary>
        /// <param name="iCol">Column ordinal</param>
        /// <param name="maxNumBytesToReturn">Number of bytes to return at maximum</param>
        /// <returns>
        /// Byte array holding at most <paramref name="maxNumBytesToReturn"/> bytes of the value
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="maxNumBytesToReturn"/> is zero or negative
        /// </exception>
        public byte[] GetBytesWithMaxCapacity(int iCol, int maxNumBytesToReturn)
        {
            if (maxNumBytesToReturn <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxNumBytesToReturn), SR.QueryServiceDataReaderByteCountInvalid);
            }

            // First, ask provider how much data it has and calculate the final # of bytes
            // NOTE: a negative value (conventionally -1) means that it doesn't know how much data it has
            long origLength = GetBytes(iCol, 0, null, 0, 0);
            long neededLength = origLength;
            if (neededLength < 0 || neededLength > maxNumBytesToReturn)
            {
                neededLength = maxNumBytesToReturn;
            }

            // Get the data up to the maxNumBytesToReturn
            byte[] bytesBuffer = new byte[neededLength];
            long bytesRead = 0;
            if (neededLength > 0)
            {
                bytesRead = GetBytes(iCol, 0, bytesBuffer, 0, (int)neededLength);
            }

            // See if server sent back more data than we should return
            if (origLength < 0 || origLength > neededLength)
            {
                // Pump the rest of data from the reader and discard it right away
                long dataIndex = neededLength;
                const int tmpBufSize = 100000;
                byte[] tmpBuf = new byte[tmpBufSize];
                while (GetBytes(iCol, dataIndex, tmpBuf, 0, tmpBufSize) == tmpBufSize)
                {
                    dataIndex += tmpBufSize;
                }
            }

            // When the length was unknown the buffer was sized to the cap; drop the unused tail
            // so callers don't see trailing zero bytes that were never part of the value.
            int resultLength = GetReturnedLength(origLength, bytesRead, bytesBuffer.Length);
            if (resultLength < bytesBuffer.Length)
            {
                Array.Resize(ref bytesBuffer, resultLength);
            }

            return bytesBuffer;
        }

        /// <summary>
        /// Retrieves characters with a maximum number of chars to return
        /// </summary>
        /// <param name="iCol">Column ordinal</param>
        /// <param name="maxCharsToReturn">Number of chars to return at maximum</param>
        /// <returns>
        /// String holding at most <paramref name="maxCharsToReturn"/> chars of the value
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="maxCharsToReturn"/> is zero or negative
        /// </exception>
        public string GetCharsWithMaxCapacity(int iCol, int maxCharsToReturn)
        {
            if (maxCharsToReturn <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxCharsToReturn), SR.QueryServiceDataReaderCharCountInvalid);
            }

            // First, ask provider how much data it has and calculate the final # of chars
            // NOTE: a negative value (conventionally -1) means that it doesn't know how much data it has
            long origLength = GetChars(iCol, 0, null, 0, 0);
            long neededLength = origLength;
            if (neededLength < 0 || neededLength > maxCharsToReturn)
            {
                neededLength = maxCharsToReturn;
            }
            Debug.Assert(neededLength < int.MaxValue);

            // Get the data up to maxCharsToReturn
            char[] buffer = new char[neededLength];
            long charsRead = 0;
            if (neededLength > 0)
            {
                charsRead = GetChars(iCol, 0, buffer, 0, (int)neededLength);
            }

            // See if server sent back more data than we should return
            if (origLength < 0 || origLength > neededLength)
            {
                // Pump the rest of data from the reader and discard it right away
                long dataIndex = neededLength;
                const int tmpBufSize = 100000;
                char[] tmpBuf = new char[tmpBufSize];
                while (GetChars(iCol, dataIndex, tmpBuf, 0, tmpBufSize) == tmpBufSize)
                {
                    dataIndex += tmpBufSize;
                }
            }

            // When the length was unknown the buffer was sized to the cap; only materialize the
            // chars that were actually read so the string has no trailing '\0' padding.
            int resultLength = GetReturnedLength(origLength, charsRead, buffer.Length);
            return new string(buffer, 0, resultLength);
        }

        /// <summary>
        /// Retrieves the value of a column as a string by calling ToString() on it
        /// </summary>
        /// <remarks>
        /// NOTE: <paramref name="maxCharsToReturn"/> is only validated here; the returned string
        /// is not truncated to that length.
        /// </remarks>
        /// <param name="iCol">Column ordinal</param>
        /// <param name="maxCharsToReturn">Number of chars to return at maximum; must be positive</param>
        /// <returns>String form of the value, or null if the value is null</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="maxCharsToReturn"/> is zero or negative
        /// </exception>
        public string GetXmlWithMaxCapacity(int iCol, int maxCharsToReturn)
        {
            if (maxCharsToReturn <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxCharsToReturn), SR.QueryServiceDataReaderXmlCountInvalid);
            }

            object o = GetValue(iCol);
            return o?.ToString();
        }

        #endregion

        #region Private Helpers

        /// <summary>
        /// Pass-through to <see cref="IDataRecord.GetBytes"/> on the wrapped reader
        /// </summary>
        private long GetBytes(int i, long dataIndex, byte[] buffer, int bufferIndex, int length)
        {
            return DataReader.GetBytes(i, dataIndex, buffer, bufferIndex, length);
        }

        /// <summary>
        /// Pass-through to <see cref="IDataRecord.GetChars"/> on the wrapped reader
        /// </summary>
        private long GetChars(int i, long dataIndex, char[] buffer, int bufferIndex, int length)
        {
            return DataReader.GetChars(i, dataIndex, buffer, bufferIndex, length);
        }

        /// <summary>
        /// Works out how many elements of a read buffer should be handed back to the caller.
        /// The buffer is only trimmed when the provider could not report the length up front
        /// (negative <paramref name="reportedLength"/>) and the read returned a sane count that
        /// is smaller than the buffer; otherwise the whole buffer is returned.
        /// </summary>
        /// <param name="reportedLength">Length the provider reported before reading</param>
        /// <param name="readCount">Count returned by the read into the buffer</param>
        /// <param name="bufferLength">Size of the buffer that was read into</param>
        /// <returns>Number of leading buffer elements to return</returns>
        private static int GetReturnedLength(long reportedLength, long readCount, int bufferLength)
        {
            bool lengthWasUnknown = reportedLength < 0;
            if (lengthWasUnknown && readCount >= 0 && readCount < bufferLength)
            {
                return (int)readCount;
            }

            return bufferLength;
        }

        #endregion

        /// <summary>
        /// Internal class for writing strings with a maximum capacity
        /// </summary>
        /// <remarks>
        /// This code is taken almost verbatim from Microsoft.SqlServer.Management.UI.Grid, SSMS
        /// DataStorage, StorageDataReader class.
        /// </remarks>
        internal class StringWriterWithMaxCapacity : StringWriter
        {
            // Set once a write had to be dropped or truncated; all later writes are ignored
            private bool stopWriting;

            private int CurrentLength
            {
                get { return GetStringBuilder().Length; }
            }

            /// <summary>
            /// Creates a writer that stops accepting characters once <paramref name="capacity"/>
            /// characters have been written
            /// </summary>
            /// <param name="formatProvider">Format provider passed to the base writer</param>
            /// <param name="capacity">Maximum number of characters to keep</param>
            public StringWriterWithMaxCapacity(IFormatProvider formatProvider, int capacity)
                : base(formatProvider)
            {
                MaximumCapacity = capacity;
            }

            private int MaximumCapacity { get; set; }

            /// <summary>
            /// Writes a single character unless the capacity has been reached
            /// </summary>
            /// <param name="value">Character to write</param>
            public override void Write(char value)
            {
                if (stopWriting)
                {
                    return;
                }

                if (CurrentLength < MaximumCapacity)
                {
                    base.Write(value);
                }
                else
                {
                    stopWriting = true;
                }
            }

            /// <summary>
            /// Writes up to <paramref name="count"/> characters starting at
            /// <paramref name="index"/>, truncating to the remaining capacity
            /// </summary>
            /// <param name="buffer">Source characters</param>
            /// <param name="index">Offset in <paramref name="buffer"/> of the first char to write</param>
            /// <param name="count">Number of chars to write</param>
            public override void Write(char[] buffer, int index, int count)
            {
                if (stopWriting)
                {
                    return;
                }

                // 'index' is only an offset into the source buffer; the number of characters
                // appended is 'count', so that alone is compared against the remaining room.
                int remaining = MaximumCapacity - CurrentLength;
                if (count > remaining)
                {
                    stopWriting = true;
                    count = Math.Max(remaining, 0);
                }

                base.Write(buffer, index, count);
            }

            /// <summary>
            /// Writes a string, truncating it to the remaining capacity
            /// </summary>
            /// <param name="value">String to write; null writes nothing</param>
            public override void Write(string value)
            {
                if (stopWriting)
                {
                    return;
                }

                // The base StringWriter treats a null string as a no-op; keep that contract
                if (value == null)
                {
                    return;
                }

                int remaining = MaximumCapacity - CurrentLength;
                if (value.Length > remaining)
                {
                    stopWriting = true;
                    base.Write(value.Substring(0, Math.Max(remaining, 0)));
                }
                else
                {
                    base.Write(value);
                }
            }
        }
    }
}