From 9f371cd0bca292e5796f0188bb371d8d9eeb4577 Mon Sep 17 00:00:00 2001 From: Benjamin Russell Date: Fri, 5 Aug 2016 18:38:21 -0700 Subject: [PATCH] Unit tests, part 1 --- .../QueryExecution/Query.cs | 14 +- .../QueryExecution/ExecuteTests.cs | 201 ++++++++++++++-- .../Utility/TestDbDataReader.cs | 215 ++++++++++++++++++ .../Utility/TestObjects.cs | 189 ++------------- .../project.json | 3 +- 5 files changed, 431 insertions(+), 191 deletions(-) create mode 100644 test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs index 3e189ece..831e981b 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -84,19 +84,25 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution { do { - // Create a new result set that we'll use to store all the data - ResultSet resultSet = new ResultSet(); - if (reader.CanGetColumnSchema()) + // TODO: This doesn't properly handle scenarios where the query is SELECT but does not have rows + if (!reader.HasRows) { - resultSet.Columns = reader.GetColumnSchema().ToArray(); + continue; } // Read until we hit the end of the result set + ResultSet resultSet = new ResultSet(); while (await reader.ReadAsync(cancellationSource.Token)) { resultSet.AddRow(reader); } + // Read off the column schema information + if (reader.CanGetColumnSchema()) + { + resultSet.Columns = reader.GetColumnSchema().ToArray(); + } + // Add the result set to the results of the query ResultSets.Add(resultSet); } while (await reader.NextResultAsync(cancellationSource.Token)); diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs index 4e2ccab1..3eb923c5 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -1,24 +1,35 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.SqlServer.Management.SqlParser.MetadataProvider; +using System.Collections.ObjectModel; +using System.Data; +using System.Data.Common; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution; -using Microsoft.SqlTools.Test.Utility; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Moq; +using Moq.Protected; using Xunit; namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution { public class ExecuteTests { + private static Dictionary[] testData = + { + new Dictionary { {"col1", "val11"}, { "col2", "val12"}, { "col3", "val13"}, { "col4", "col14"} }, + new Dictionary { {"col1", "val21"}, { "col2", "val22"}, { "col3", "val23"}, { "col4", "col24"} }, + new Dictionary { {"col1", "val31"}, { "col2", "val32"}, { "col3", "val33"}, { "col4", "col34"} }, + new Dictionary { {"col1", "val41"}, { "col2", "val42"}, { "col3", "val43"}, { "col4", "col44"} }, + new Dictionary { {"col1", "val51"}, { "col2", "val52"}, { "col3", "val53"}, { "col4", "col54"} }, + }; + [Fact] public void QueryCreationTest() { // If I create a new query... - Query query = new Query("NO OP", CreateTestConnectionInfo()); + Query query = new Query("NO OP", CreateTestConnectionInfo(null)); // Then: // ... It should not have executed @@ -29,7 +40,173 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution Assert.Empty(query.ResultSummary); } - private static ConnectionInfo CreateTestConnectionInfo() + [Fact] + public void QueryExecuteNoResultSets() + { + // If I execute a query that should get no result sets + Query query = new Query("Query with no result sets", CreateTestConnectionInfo(null)); + query.Execute().Wait(); + + // Then: + // ... It should have executed + Assert.True(query.HasExecuted, "The query should have been marked executed."); + + // ... The results should be empty + Assert.Empty(query.ResultSets); + Assert.Empty(query.ResultSummary); + + // ... The results should not be null + Assert.NotNull(query.ResultSets); + Assert.NotNull(query.ResultSummary); + } + + [Fact] + public void QueryExecuteQueryOneResultSet() + { + ConnectionInfo ci = CreateTestConnectionInfo(new[] {testData}); + + // If I execute a query that should get one result set + int resultSets = 1; + int rows = 5; + int columns = 4; + Query query = new Query("Query with one result sets", ci); + query.Execute().Wait(); + + // Then: + // ... It should have executed + Assert.True(query.HasExecuted, "The query should have been marked executed."); + + // ... There should be exactly one result set + Assert.Equal(resultSets, query.ResultSets.Count); + + // ... Inside the result set should be with 5 rows + Assert.Equal(rows, query.ResultSets[0].Rows.Count); + + // ... Inside the result set should have 5 columns and 5 column definitions + Assert.Equal(columns, query.ResultSets[0].Rows[0].Length); + Assert.Equal(columns, query.ResultSets[0].Columns.Length); + + // ... There should be exactly one result set summary + Assert.Equal(resultSets, query.ResultSummary.Length); + + // ... Inside the result summary, there should be 5 column definitions + Assert.Equal(columns, query.ResultSummary[0].ColumnInfo.Length); + + // ... Inside the result summary, there should be 5 rows + Assert.Equal(rows, query.ResultSummary[0].RowCount); + } + + [Fact] + public void QueryExecuteQueryTwoResultSets() + { + var dataset = new[] {testData, testData}; + int resultSets = dataset.Length; + int rows = testData.Length; + int columns = testData[0].Count; + ConnectionInfo ci = CreateTestConnectionInfo(dataset); + + // If I execute a query that should get two result sets + Query query = new Query("Query with two result sets", ci); + query.Execute().Wait(); + + // Then: + // ... It should have executed + Assert.True(query.HasExecuted, "The query should have been marked executed."); + + // ... There should be exactly two result sets + Assert.Equal(resultSets, query.ResultSets.Count); + + foreach (ResultSet rs in query.ResultSets) + { + // ... Each result set should have 5 rows + Assert.Equal(rows, rs.Rows.Count); + + // ... Inside each result set should be 5 columns and 5 column definitions + Assert.Equal(columns, rs.Rows[0].Length); + Assert.Equal(columns, rs.Columns.Length); + } + + // ... There should be exactly two result set summaries + Assert.Equal(resultSets, query.ResultSummary.Length); + + foreach (ResultSetSummary rs in query.ResultSummary) + { + // ... Inside each result summary, there should be 5 column definitions + Assert.Equal(columns, rs.ColumnInfo.Length); + + // ... Inside each result summary, there should be 5 rows + Assert.Equal(rows, rs.RowCount); + } + } + + #region Mocking + + //private static DbDataReader CreateTestReader(int columnCount, int rowCount) + //{ + // var readerMock = new Mock { CallBase = true }; + + // // Setup for column reads + // // TODO: We can't test columns because of oddities with how datatable/GetColumn + + // // Setup for row reads + // var readSequence = readerMock.SetupSequence(dbReader => dbReader.Read()); + // for (int i = 0; i < rowCount; i++) + // { + // readSequence.Returns(true); + // } + // readSequence.Returns(false); + + // // Make sure that if we call for data from the reader it works + // readerMock.Setup(dbReader => dbReader[InColumnRange(columnCount)]) + // .Returns(i => i.ToString()); + // readerMock.Setup(dbReader => dbReader[NotInColumnRange(columnCount)]) + // .Throws(new ArgumentOutOfRangeException()); + // readerMock.Setup(dbReader => dbReader.HasRows) + // .Returns(rowCount > 0); + + // return readerMock.Object; + //} + + //private static int InColumnRange(int columnCount) + //{ + // return Match.Create(i => i < columnCount && i > 0); + //} + + //private static int NotInColumnRange(int columnCount) + //{ + // return Match.Create(i => i >= columnCount || i < 0); + //} + + private static DbCommand CreateTestCommand(Dictionary[][] data) + { + var commandMock = new Mock {CallBase = true}; + commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()) + .Returns(new TestDbDataReader(data)); + + return commandMock.Object; + } + + private static DbConnection CreateTestConnection(Dictionary[][] data) + { + var connectionMock = new Mock {CallBase = true}; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data)); + + return connectionMock.Object; + } + + private static ISqlConnectionFactory CreateMockFactory(Dictionary[][] data) + { + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateTestConnection(data)); + + return mockFactory.Object; + } + + private static ConnectionInfo CreateTestConnectionInfo(Dictionary[][] data) { // Create connection info ConnectionDetails connDetails = new ConnectionDetails @@ -40,15 +217,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution ServerName = "sqltools11" }; -#if !USE_LIVE_CONNECTION - // Use the mock db connection factory - ISqlConnectionFactory factory = new TestSqlConnectionFactory(); -#else - // Use a real db connection factory - ISqlConnectionFactory factory = new SqlConnectionFactory(); -#endif - - return new ConnectionInfo(factory, "test://test", connDetails); + return new ConnectionInfo(CreateMockFactory(data), "test://test", connDetails); } + + #endregion } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs new file mode 100644 index 00000000..0031ad4a --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs @@ -0,0 +1,215 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Data.Common; +using System.Linq; +using Moq; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public class TestDbDataReader : DbDataReader, IDbColumnSchemaGenerator + { + + #region Test Specific Implementations + + private Dictionary[][] Data { get; set; } + + public IEnumerator[]> ResultSet { get; private set; } + + private IEnumerator> Rows { get; set; } + + private const string tableNameTestCommand = "SELECT name FROM sys.tables"; + + private List> tableNamesTest = new List> + { + new Dictionary { {"name", "table1"} }, + new Dictionary { {"name", "table2"} } + }; + + public TestDbDataReader(Dictionary[][] data) + { + Data = data; + if (Data != null) + { + ResultSet = ((IEnumerable[]>) Data).GetEnumerator(); + ResultSet.MoveNext(); + } + } + + #endregion + + public override bool HasRows + { + get { return ResultSet != null && ResultSet.Current.Length > 0; } + } + + public override bool Read() + { + if (Rows == null) + { + Rows = ((IEnumerable>) ResultSet.Current).GetEnumerator(); + } + return Rows.MoveNext(); + } + + public override bool NextResult() + { + if (Data == null || !ResultSet.MoveNext()) + { + return false; + } + Rows = ((IEnumerable>)ResultSet.Current).GetEnumerator(); + return true; + } + + public override object GetValue(int ordinal) + { + return this[ordinal]; + } + + public override object this[string name] + { + get { return Rows.Current[name]; } + } + + public override object this[int ordinal] + { + get { return Rows.Current[Rows.Current.Keys.AsEnumerable().ToArray()[ordinal]]; } + } + + public ReadOnlyCollection GetColumnSchema() + { + if (ResultSet?.Current == null || ResultSet.Current.Length <= 0) + { + return new ReadOnlyCollection(new List()); + } + + List columns = new List(); + for (int i = 0; i < ResultSet.Current[0].Count; i++) + { + columns.Add(new Mock().Object); + } + return new ReadOnlyCollection(columns); + } + + public override int FieldCount { get { return Rows?.Current.Count ?? 0; } } + + #region Not Implemented + + public override bool GetBoolean(int ordinal) + { + throw new NotImplementedException(); + } + + public override byte GetByte(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + { + throw new NotImplementedException(); + } + + public override char GetChar(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + { + throw new NotImplementedException(); + } + + public override string GetDataTypeName(int ordinal) + { + throw new NotImplementedException(); + } + + public override DateTime GetDateTime(int ordinal) + { + throw new NotImplementedException(); + } + + public override decimal GetDecimal(int ordinal) + { + throw new NotImplementedException(); + } + + public override double GetDouble(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetOrdinal(string name) + { + throw new NotImplementedException(); + } + + public override string GetName(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetInt64(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetInt32(int ordinal) + { + throw new NotImplementedException(); + } + + public override short GetInt16(int ordinal) + { + throw new NotImplementedException(); + } + + public override Guid GetGuid(int ordinal) + { + throw new NotImplementedException(); + } + + public override float GetFloat(int ordinal) + { + throw new NotImplementedException(); + } + + public override Type GetFieldType(int ordinal) + { + throw new NotImplementedException(); + } + + public override string GetString(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetValues(object[] values) + { + throw new NotImplementedException(); + } + + public override bool IsDBNull(int ordinal) + { + throw new NotImplementedException(); + } + + public override IEnumerator GetEnumerator() + { + throw new NotImplementedException(); + } + + public override int Depth { get; } + public override bool IsClosed { get; } + public override int RecordsAffected { get; } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs index cda0ed5a..5ca94d2b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs @@ -18,6 +18,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Xunit; namespace Microsoft.SqlTools.Test.Utility @@ -97,179 +98,18 @@ namespace Microsoft.SqlTools.Test.Utility } } - public class TestDataReader : DbDataReader - { - - #region Test Specific Implementations - - internal string SqlCommandText { get; set; } - - private const string tableNameTestCommand = "SELECT name FROM sys.tables"; - - private List> tableNamesTest = new List> - { - new Dictionary { {"name", "table1"} }, - new Dictionary { {"name", "table2"} } - }; - - private IEnumerator> tableEnumerator; - - #endregion - - public override bool GetBoolean(int ordinal) - { - throw new NotImplementedException(); - } - - public override byte GetByte(int ordinal) - { - throw new NotImplementedException(); - } - - public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) - { - throw new NotImplementedException(); - } - - public override char GetChar(int ordinal) - { - throw new NotImplementedException(); - } - - public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) - { - throw new NotImplementedException(); - } - - public override string GetDataTypeName(int ordinal) - { - throw new NotImplementedException(); - } - - public override DateTime GetDateTime(int ordinal) - { - throw new NotImplementedException(); - } - - public override decimal GetDecimal(int ordinal) - { - throw new NotImplementedException(); - } - - public override double GetDouble(int ordinal) - { - throw new NotImplementedException(); - } - - public override IEnumerator GetEnumerator() - { - throw new NotImplementedException(); - } - - public override int GetOrdinal(string name) - { - throw new NotImplementedException(); - } - - public override string GetName(int ordinal) - { - throw new NotImplementedException(); - } - - public override long GetInt64(int ordinal) - { - throw new NotImplementedException(); - } - - public override int GetInt32(int ordinal) - { - throw new NotImplementedException(); - } - - public override short GetInt16(int ordinal) - { - throw new NotImplementedException(); - } - - public override Guid GetGuid(int ordinal) - { - throw new NotImplementedException(); - } - - public override float GetFloat(int ordinal) - { - throw new NotImplementedException(); - } - - public override Type GetFieldType(int ordinal) - { - throw new NotImplementedException(); - } - - public override string GetString(int ordinal) - { - throw new NotImplementedException(); - } - - public override object GetValue(int ordinal) - { - throw new NotImplementedException(); - } - - public override int GetValues(object[] values) - { - throw new NotImplementedException(); - } - - public override bool IsDBNull(int ordinal) - { - throw new NotImplementedException(); - } - - public override bool NextResult() - { - throw new NotImplementedException(); - } - - public override bool Read() - { - if (tableEnumerator == null) - { - switch (SqlCommandText) - { - case tableNameTestCommand: - tableEnumerator = ((IEnumerable>)tableNamesTest).GetEnumerator(); - break; - default: - throw new NotImplementedException(); - } - } - return tableEnumerator.MoveNext(); - } - - public override int Depth { get; } - public override bool IsClosed { get; } - public override int RecordsAffected { get; } - - public override object this[string name] - { - get { return tableEnumerator.Current[name]; } - } - - public override object this[int ordinal] - { - get { return tableEnumerator.Current[tableEnumerator.Current.Keys.ToArray()[ordinal]]; } - } - - public override int FieldCount { get; } - public override bool HasRows { get; } - } - /// /// Test mock class for IDbCommand /// public class TestSqlCommand : DbCommand { + internal TestSqlCommand(Dictionary[][] data) + { + Data = data; + } + + internal Dictionary[][] Data { get; set; } + public override void Cancel() { throw new NotImplementedException(); @@ -306,7 +146,7 @@ namespace Microsoft.SqlTools.Test.Utility protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) { - return new TestDataReader {SqlCommandText = CommandText}; + return new TestDbDataReader(Data); } } @@ -315,6 +155,13 @@ namespace Microsoft.SqlTools.Test.Utility /// public class TestSqlConnection : DbConnection { + internal TestSqlConnection(Dictionary[][] data) + { + Data = data; + } + + internal Dictionary[][] Data { get; set; } + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) { throw new NotImplementedException(); @@ -342,7 +189,7 @@ namespace Microsoft.SqlTools.Test.Utility protected override DbCommand CreateDbCommand() { - return new TestSqlCommand(); + return new TestSqlCommand(Data); } public override void ChangeDatabase(string databaseName) @@ -358,7 +205,7 @@ namespace Microsoft.SqlTools.Test.Utility { public DbConnection CreateSqlConnection(string connectionString) { - return new TestSqlConnection() + return new TestSqlConnection(null) { ConnectionString = connectionString }; diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json index 3d023cd4..23c97d0b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json @@ -14,7 +14,8 @@ "moq.netcore": "4.4.0-beta8", "Microsoft.SqlTools.ServiceLayer": { "target": "project" - } + }, + "System.Diagnostics.TraceSource": "4.0.0" }, "testRunner": "xunit", "frameworks": {