//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Common;
using Microsoft.Data.SqlClient;
using System.Linq;

namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
{
    /// <summary>
    /// Concrete <see cref="DbException"/> thrown by <see cref="TestDbDataReader.Read"/> when the
    /// reader was constructed with <c>throwOnRead</c> set to <c>true</c>.
    /// </summary>
    public class TestDbException : DbException
    {
    }

    /// <summary>
    /// In-memory <see cref="DbDataReader"/> test double that replays a sequence of
    /// <see cref="TestResultSet"/> instances, one result set at a time.
    /// </summary>
    public class TestDbDataReader : DbDataReader, IDbColumnSchemaGenerator
    {
        #region Test Specific Implementations

        private IEnumerable<TestResultSet> Data { get; }

        private IEnumerator<TestResultSet> ResultSetEnumerator { get; }

        /// <summary>
        /// The result set the reader is positioned on, or <c>null</c> when there is none: no data
        /// was supplied, the sequence was empty, or <see cref="NextResult"/> moved past the end.
        /// Cached so that <see cref="IEnumerator{T}.Current"/> is never read after a failed
        /// MoveNext, where its value is undefined by the enumerator contract.
        /// </summary>
        private TestResultSet CurrentResultSet { get; set; }

        private IEnumerator<object[]> RowEnumerator { get; set; }

        private bool ThrowOnRead { get; }

        /// <summary>
        /// Creates a reader over <paramref name="data"/>, positioned on its first result set.
        /// </summary>
        /// <param name="data">Result sets to replay; <c>null</c> yields a reader with no results.</param>
        /// <param name="throwOnRead">When <c>true</c>, every call to <see cref="Read"/> throws <see cref="TestDbException"/>.</param>
        public TestDbDataReader(IEnumerable<TestResultSet> data, bool throwOnRead)
        {
            ThrowOnRead = throwOnRead;
            Data = data;
            if (Data != null)
            {
                ResultSetEnumerator = Data.GetEnumerator();
                if (ResultSetEnumerator.MoveNext())
                {
                    CurrentResultSet = ResultSetEnumerator.Current;
                }
            }
        }

        /// <summary>
        /// Copies up to <paramref name="length"/> elements of <paramref name="source"/>, starting at
        /// <paramref name="dataOffset"/>, into <paramref name="buffer"/> at <paramref name="bufferOffset"/>.
        /// </summary>
        /// <returns>The number of elements actually copied; 0 when the offset is at or past the end.</returns>
        private static int CopyChunk<T>(T[] source, long dataOffset, T[] buffer, int bufferOffset, int length)
        {
            long remaining = Math.Max(0L, source.Length - dataOffset);
            int count = (int)Math.Max(0L, Math.Min(length, remaining));
            if (count > 0)
            {
                Array.Copy(source, (int)dataOffset, buffer, bufferOffset, count);
            }

            return count;
        }

        #endregion

        #region Properties

        /// <summary>Number of columns in the current result set, or 0 when there is no current result set.</summary>
        public override int FieldCount => CurrentResultSet?.Columns.Count ?? 0;

        /// <summary>Whether the current result set contains at least one row (false when there is none).</summary>
        public override bool HasRows => CurrentResultSet?.Rows.Count > 0;

        /// <summary>
        /// Mimics the behavior of SqlDataReader: -1 once a row enumerator exists (after
        /// <see cref="Read"/> or a successful <see cref="NextResult"/>), otherwise 1.
        /// </summary>
        public override int RecordsAffected => RowEnumerator != null ? -1 : 1;

        /// <summary>Value of the cell at <paramref name="ordinal"/> in the current row.</summary>
        public override object this[int ordinal] => RowEnumerator.Current[ordinal];

        /// <summary>
        /// Value of the cell in the current row whose column is named <paramref name="name"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">There is no current result set.</exception>
        /// <exception cref="IndexOutOfRangeException">No column with the given name exists.</exception>
        public override object this[string name]
        {
            get
            {
                if (CurrentResultSet == null)
                {
                    throw new ArgumentOutOfRangeException(nameof(name));
                }

                int ordinal = CurrentResultSet.Columns.FindIndex(c => c.ColumnName == name);
                if (ordinal < 0)
                {
                    // FindIndex returns -1 for a miss; report it explicitly instead of indexing row[-1]
                    throw new IndexOutOfRangeException(name);
                }

                return RowEnumerator.Current[ordinal];
            }
        }

        #endregion

        #region Implemented Methods

        /// <summary>
        /// Advances to the next row of the current result set. On the first call, a row enumerator
        /// is created for the current result set.
        /// </summary>
        /// <returns>True if there was another row, false otherwise (including when there is no result set).</returns>
        /// <exception cref="TestDbException">The reader was configured to throw on read.</exception>
        public override bool Read()
        {
            if (ThrowOnRead)
            {
                throw new TestDbException();
            }

            if (RowEnumerator == null)
            {
                if (CurrentResultSet == null)
                {
                    // Null or empty data: there is nothing to read
                    return false;
                }

                RowEnumerator = CurrentResultSet.GetEnumerator();
            }

            return RowEnumerator.MoveNext();
        }

        /// <summary>
        /// Advances to the next result set and creates a row enumerator for it.
        /// </summary>
        /// <returns>True if another result set exists, false otherwise.</returns>
        public override bool NextResult()
        {
            if (Data == null || !ResultSetEnumerator.MoveNext())
            {
                CurrentResultSet = null;
                return false;
            }

            CurrentResultSet = ResultSetEnumerator.Current;
            RowEnumerator = CurrentResultSet.GetEnumerator();
            return true;
        }

        /// <summary>
        /// Retrieves the value of the cell in the given column of the current row.
        /// </summary>
        /// <param name="ordinal">Ordinal of the column.</param>
        /// <returns>The object in the cell.</returns>
        public override object GetValue(int ordinal)
        {
            return this[ordinal];
        }

        /// <summary>
        /// Copies the cells of the current row into <paramref name="values"/>. At most
        /// <c>values.Length</c> cells are copied, per the DbDataReader contract.
        /// </summary>
        /// <param name="values">Destination for the cell values.</param>
        /// <returns>The number of cells copied.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
        public override int GetValues(object[] values)
        {
            if (values == null)
            {
                throw new ArgumentNullException(nameof(values));
            }

            object[] row = RowEnumerator.Current;
            int count = Math.Min(values.Length, row.Length);
            Array.Copy(row, values, count);
            return count;
        }

        /// <summary>
        /// Whether the cell in the given column of the current row is null or <see cref="DBNull"/>.
        /// </summary>
        /// <param name="ordinal">Ordinal of the column.</param>
        /// <returns>True if the cell is null or <see cref="DBNull.Value"/>, false otherwise.</returns>
        public override bool IsDBNull(int ordinal)
        {
            object value = this[ordinal];
            return value == null || value is DBNull;
        }

        /// <summary>
        /// Returns the column schema of the current result set.
        /// </summary>
        /// <returns>The columns of the current result set, or an empty collection when there is none.</returns>
        public ReadOnlyCollection<DbColumn> GetColumnSchema()
        {
            if (CurrentResultSet == null)
            {
                return new ReadOnlyCollection<DbColumn>(new List<DbColumn>());
            }

            return new ReadOnlyCollection<DbColumn>(CurrentResultSet.Columns);
        }

        /// <summary>
        /// Reads bytes from a <c>byte[]</c> column of the current row.
        /// </summary>
        /// <returns>
        /// The total field length when <paramref name="buffer"/> is null; otherwise the number of
        /// bytes copied, which may be fewer than <paramref name="length"/> near the end of the field.
        /// </returns>
        /// <exception cref="InvalidOperationException">There is no current result set, or the column is not of type <c>byte[]</c>.</exception>
        public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length)
        {
            if (CurrentResultSet == null || CurrentResultSet.Columns[ordinal].DataType != typeof(byte[]))
            {
                throw new InvalidOperationException();
            }

            byte[] data = (byte[])this[ordinal];
            if (buffer == null)
            {
                return data.Length;
            }

            return CopyChunk(data, dataOffset, buffer, bufferOffset, length);
        }

        /// <summary>
        /// Reads characters from a string cell of the current row.
        /// </summary>
        /// <returns>
        /// The total string length when <paramref name="buffer"/> is null; otherwise the number of
        /// characters copied, which may be fewer than <paramref name="length"/> near the end of the string.
        /// </returns>
        public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length)
        {
            char[] allChars = ((string)RowEnumerator.Current[ordinal]).ToCharArray();
            if (buffer == null)
            {
                return allChars.Length;
            }

            return CopyChunk(allChars, dataOffset, buffer, bufferOffset, length);
        }

        /// <summary>
        /// Parses a string cell of the current row as an integer. This test double only accepts
        /// single-character numeric strings.
        /// </summary>
        /// <exception cref="InvalidCastException">The cell is null, not exactly one character long, or not numeric.</exception>
        public override int GetInt32(int ordinal)
        {
            string text = (string)RowEnumerator.Current[ordinal];
            if (text == null || text.Length != 1 || !int.TryParse(text, out int result))
            {
                throw new InvalidCastException();
            }

            return result;
        }

        #endregion

        #region Not Implemented

        public override bool GetBoolean(int ordinal) { throw new NotImplementedException(); }

        public override byte GetByte(int ordinal) { throw new NotImplementedException(); }

        public override char GetChar(int ordinal) { throw new NotImplementedException(); }

        public override string GetDataTypeName(int ordinal) { throw new NotImplementedException(); }

        public override DateTime GetDateTime(int ordinal) { throw new NotImplementedException(); }

        public override decimal GetDecimal(int ordinal) { throw new NotImplementedException(); }

        public override double GetDouble(int ordinal) { throw new NotImplementedException(); }

        public override int GetOrdinal(string name) { throw new NotImplementedException(); }

        public override string GetName(int ordinal) { throw new NotImplementedException(); }

        public override long GetInt64(int ordinal) { throw new NotImplementedException(); }

        public override short GetInt16(int ordinal) { throw new NotImplementedException(); }

        public override Guid GetGuid(int ordinal) { throw new NotImplementedException(); }

        public override float GetFloat(int ordinal) { throw new NotImplementedException(); }

        public override Type GetFieldType(int ordinal) { throw new NotImplementedException(); }

        public override string GetString(int ordinal) { throw new NotImplementedException(); }

        public override IEnumerator GetEnumerator() { throw new NotImplementedException(); }

        public override int Depth
        {
            get { throw new NotImplementedException(); }
        }

        public override bool IsClosed
        {
            get { throw new NotImplementedException(); }
        }

        #endregion
    }
}