//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Kusto.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.Utility;
namespace Microsoft.Kusto.ServiceLayer.QueryExecution.DataStorage
{
    /// <summary>
    /// Wrapper around an <see cref="IDataReader"/> to perform some special operations more simply,
    /// such as reading binary/character/xml cells with an upper bound on the amount returned.
    /// </summary>
    public class StorageDataReader
    {
        /// <summary>
        /// Size of the scratch buffer used to pump (and discard) cell data beyond the requested maximum.
        /// </summary>
        private const int DiscardBufferSize = 100000;

        /// <summary>
        /// Constructs a new wrapper around the provided reader and captures its column metadata.
        /// </summary>
        /// <param name="reader">The reader to wrap around. Must not be null.</param>
        public StorageDataReader(IDataReader reader)
        {
            // Sanity check to make sure there is a data reader
            Validate.IsNotNull(nameof(reader), reader);
            DataReader = reader;

            // Read the columns into a set of wrappers, one per row of the schema table.
            // A null schema table is treated as "no columns" instead of failing with a NullReferenceException.
            DataTable schemaTable = DataReader.GetSchemaTable();
            Columns = schemaTable?.Rows
                          .Cast<DataRow>()
                          .Select(row => new DbColumnWrapper(row))
                          .ToArray()
                      ?? Array.Empty<DbColumnWrapper>();

            HasLongColumns = Columns.Any(column => column.IsLong.HasValue && column.IsLong.Value);
        }

        #region Properties

        /// <summary>
        /// All the columns that this reader currently contains
        /// </summary>
        public DbColumnWrapper[] Columns { get; }

        /// <summary>
        /// The <see cref="IDataReader"/> that will be read from
        /// </summary>
        public IDataReader DataReader { get; }

        /// <summary>
        /// Whether or not any of the columns of this reader are 'long' (as reported by
        /// <c>DbColumnWrapper.IsLong</c>), such as nvarchar(max)
        /// </summary>
        public bool HasLongColumns { get; }

        #endregion

        #region DbDataReader Methods

        /// <summary>
        /// Asynchronous pass-through to <see cref="IDataReader.Read"/>.
        /// </summary>
        /// <param name="cancellationToken">The cancellation token to use for cancelling a query</param>
        /// <returns>
        /// A task whose result is true if the reader advanced to another row, false if there are no more rows.
        /// The task is cancelled if <paramref name="cancellationToken"/> is signalled before the read starts.
        /// </returns>
        public Task<bool> ReadAsync(CancellationToken cancellationToken)
        {
            // IDataReader only exposes a blocking Read(), so it is offloaded to the thread pool.
            // Flowing the token means an already-cancelled query does not start another read.
            return Task.Run(() => DataReader.Read(), cancellationToken);
        }

        /// <summary>
        /// Retrieves a value
        /// </summary>
        /// <param name="i">Column ordinal</param>
        /// <returns>The value of the given column</returns>
        public object GetValue(int i)
        {
            return DataReader.GetValue(i);
        }

        /// <summary>
        /// Stores all values of the current row into the provided object array
        /// </summary>
        /// <param name="values">Where to store the values from this row</param>
        public void GetValues(object[] values)
        {
            DataReader.GetValues(values);
        }

        /// <summary>
        /// Whether or not the cell of the given column at the current row is a DBNull
        /// </summary>
        /// <param name="i">Column ordinal</param>
        /// <returns>True if the cell is DBNull, false otherwise</returns>
        public bool IsDBNull(int i)
        {
            return DataReader.IsDBNull(i);
        }

        #endregion

        #region Public Methods

        /// <summary>
        /// Retrieves bytes with a maximum number of bytes to return. Any data beyond the maximum is
        /// read from the provider and discarded.
        /// </summary>
        /// <param name="iCol">Column ordinal</param>
        /// <param name="maxNumBytesToReturn">Number of bytes to return at maximum</param>
        /// <returns>
        /// Byte array holding at most <paramref name="maxNumBytesToReturn"/> bytes; it is never padded
        /// beyond the number of bytes the provider actually returned.
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxNumBytesToReturn"/> is not positive</exception>
        public byte[] GetBytesWithMaxCapacity(int iCol, int maxNumBytesToReturn)
        {
            if (maxNumBytesToReturn <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxNumBytesToReturn), SR.QueryServiceDataReaderByteCountInvalid);
            }

            // First, ask the provider how much data it has; a negative value (-1) means it doesn't know
            long origLength = GetBytes(iCol, 0, null, 0, 0);
            int neededLength = GetCappedLength(origLength, maxNumBytesToReturn);

            // Get the data up to maxNumBytesToReturn
            byte[] bytesBuffer = new byte[neededLength];
            int bytesRead = neededLength > 0
                ? ClampReadCount(GetBytes(iCol, 0, bytesBuffer, 0, neededLength), neededLength)
                : 0;

            // See if the server sent back more data than we should return
            if (origLength < 0 || origLength > neededLength)
            {
                DiscardRemainingBytes(iCol, neededLength);
            }

            // When the length was unknown the cell may be shorter than the cap;
            // drop the unfilled tail so callers do not get trailing zero bytes.
            if (bytesRead < bytesBuffer.Length)
            {
                Array.Resize(ref bytesBuffer, bytesRead);
            }

            return bytesBuffer;
        }

        /// <summary>
        /// Retrieves characters with a maximum number of chars to return. Any data beyond the maximum is
        /// read from the provider and discarded.
        /// </summary>
        /// <param name="iCol">Column ordinal</param>
        /// <param name="maxCharsToReturn">Number of chars to return at maximum</param>
        /// <returns>
        /// String of at most <paramref name="maxCharsToReturn"/> chars, containing only the characters the
        /// provider actually returned (no '\0' padding).
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxCharsToReturn"/> is not positive</exception>
        public string GetCharsWithMaxCapacity(int iCol, int maxCharsToReturn)
        {
            if (maxCharsToReturn <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxCharsToReturn), SR.QueryServiceDataReaderCharCountInvalid);
            }

            // First, ask the provider how much data it has; a negative value (-1) means it doesn't know
            long origLength = GetChars(iCol, 0, null, 0, 0);
            int neededLength = GetCappedLength(origLength, maxCharsToReturn);

            // Get the data up to maxCharsToReturn
            char[] buffer = new char[neededLength];
            int charsRead = neededLength > 0
                ? ClampReadCount(GetChars(iCol, 0, buffer, 0, neededLength), neededLength)
                : 0;

            // See if the server sent back more data than we should return
            if (origLength < 0 || origLength > neededLength)
            {
                DiscardRemainingChars(iCol, neededLength);
            }

            // Only the characters actually read are used, so an unknown-length cell that is
            // shorter than the cap does not produce a string padded with '\0' characters.
            return new string(buffer, 0, charsRead);
        }

        /// <summary>
        /// Retrieves xml with a maximum number of chars to return
        /// </summary>
        /// <param name="iCol">Column ordinal</param>
        /// <param name="maxCharsToReturn">Number of chars to return at maximum</param>
        /// <returns>
        /// The string form of the cell value, truncated to <paramref name="maxCharsToReturn"/> chars,
        /// or null if the value itself is null
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxCharsToReturn"/> is not positive</exception>
        public string GetXmlWithMaxCapacity(int iCol, int maxCharsToReturn)
        {
            if (maxCharsToReturn <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(maxCharsToReturn), SR.QueryServiceDataReaderXmlCountInvalid);
            }

            string xml = GetValue(iCol)?.ToString();

            // Honor the cap the same way GetCharsWithMaxCapacity does
            return xml != null && xml.Length > maxCharsToReturn
                ? xml.Substring(0, maxCharsToReturn)
                : xml;
        }

        #endregion

        #region Private Helpers

        private long GetBytes(int i, long dataIndex, byte[] buffer, int bufferIndex, int length)
        {
            return DataReader.GetBytes(i, dataIndex, buffer, bufferIndex, length);
        }

        private long GetChars(int i, long dataIndex, char[] buffer, int bufferIndex, int length)
        {
            return DataReader.GetChars(i, dataIndex, buffer, bufferIndex, length);
        }

        /// <summary>
        /// Turns a provider-reported length (negative meaning "unknown") into the number of items to read,
        /// never exceeding <paramref name="maxLength"/>.
        /// </summary>
        private static int GetCappedLength(long reportedLength, int maxLength)
        {
            return reportedLength < 0 || reportedLength > maxLength ? maxLength : (int)reportedLength;
        }

        /// <summary>
        /// Clamps a provider-reported read count into [0, bufferLength] so it can safely size the result.
        /// </summary>
        private static int ClampReadCount(long readCount, int bufferLength)
        {
            return (int)Math.Max(0L, Math.Min(readCount, bufferLength));
        }

        /// <summary>
        /// Pumps the rest of a binary cell, starting at <paramref name="dataIndex"/>, from the reader and
        /// discards it; stops once the provider returns less than a full chunk.
        /// </summary>
        private void DiscardRemainingBytes(int iCol, long dataIndex)
        {
            byte[] scratch = new byte[DiscardBufferSize];
            while (GetBytes(iCol, dataIndex, scratch, 0, DiscardBufferSize) == DiscardBufferSize)
            {
                dataIndex += DiscardBufferSize;
            }
        }

        /// <summary>
        /// Pumps the rest of a character cell, starting at <paramref name="dataIndex"/>, from the reader and
        /// discards it; stops once the provider returns less than a full chunk.
        /// </summary>
        private void DiscardRemainingChars(int iCol, long dataIndex)
        {
            char[] scratch = new char[DiscardBufferSize];
            while (GetChars(iCol, dataIndex, scratch, 0, DiscardBufferSize) == DiscardBufferSize)
            {
                dataIndex += DiscardBufferSize;
            }
        }

        #endregion

        /// <summary>
        /// Internal class for writing strings with a maximum capacity
        /// </summary>
        /// <remarks>
        /// This code is taken almost verbatim from Microsoft.SqlServer.Management.UI.Grid, SSMS
        /// DataStorage, StorageDataReader class. Once a write would exceed the capacity, the text is
        /// truncated to fit and every later write is ignored.
        /// </remarks>
        internal class StringWriterWithMaxCapacity : StringWriter
        {
            private bool stopWriting;

            public StringWriterWithMaxCapacity(IFormatProvider formatProvider, int capacity) : base(formatProvider)
            {
                MaximumCapacity = capacity;
            }

            private int MaximumCapacity { get; }

            private int CurrentLength => GetStringBuilder().Length;

            public override void Write(char value)
            {
                if (stopWriting) { return; }

                if (CurrentLength < MaximumCapacity)
                {
                    base.Write(value);
                }
                else
                {
                    stopWriting = true;
                }
            }

            public override void Write(char[] buffer, int index, int count)
            {
                if (stopWriting) { return; }

                // 'count' is the number of characters to write starting at 'index' (not an end index),
                // so it alone is compared against the space that is left.
                int remaining = MaximumCapacity - CurrentLength;
                if (count > remaining)
                {
                    stopWriting = true;
                    count = Math.Max(remaining, 0);
                }

                base.Write(buffer, index, count);
            }

            public override void Write(string value)
            {
                // A null string writes nothing, matching the base StringWriter behavior
                if (stopWriting || value is null) { return; }

                int remaining = MaximumCapacity - CurrentLength;
                if (value.Length > remaining)
                {
                    stopWriting = true;
                    base.Write(value.Substring(0, Math.Max(remaining, 0)));
                }
                else
                {
                    base.Write(value);
                }
            }
        }
    }
}