//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Test.Utility;
using Microsoft.SqlTools.ServiceLayer.Workspace;
using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts;
using Moq;
using Moq.Protected;

namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
{
    /// <summary>
    /// Shared constants, canned result data, and Moq-based builders for connections,
    /// file-stream factories, and services used by the query execution tests.
    /// </summary>
    public class Common
    {
        #region Constants

        public const string InvalidQuery = "SELECT *** FROM sys.objects";

        public const string NoOpQuery = "-- No ops here, just us chickens.";

        public const int Ordinal = 100; // We'll pick something other than default(int)

        public const string OwnerUri = "testFile";

        public const int StandardColumns = 5;

        public const string StandardQuery = "SELECT * FROM sys.objects";

        public const int StandardRows = 5;

        public const string UdtQuery = "SELECT hierarchyid::Parse('/')";

        public const SelectionData WholeDocument = null;

        public static readonly ConnectionDetails StandardConnectionDetails = new ConnectionDetails
        {
            DatabaseName = "123",
            Password = "456",
            ServerName = "789",
            UserName = "012"
        };

        public static readonly SelectionData SubsectionDocument = new SelectionData(0, 0, 2, 2);

        #endregion

        /// <summary>
        /// A freshly generated result set of <see cref="StandardRows"/> rows, each with
        /// <see cref="StandardColumns"/> columns. A new array is built on every access.
        /// </summary>
        public static Dictionary<string, string>[] StandardTestData
        {
            // GetTestData's parameter order is (columns, rows).
            get { return GetTestData(StandardColumns, StandardRows); }
        }

        #region Public Methods

        /// <summary>
        /// Creates a <see cref="Batch"/> for <see cref="StandardQuery"/> over
        /// <see cref="SubsectionDocument"/> and executes it against a mock connection that
        /// returns a single <see cref="StandardTestData"/> result set.
        /// </summary>
        /// <returns>The batch, after its execution task has completed.</returns>
        public static Batch GetBasicExecutedBatch()
        {
            Batch batch = new Batch(StandardQuery, SubsectionDocument, 1,
                GetFileStreamFactory(new Dictionary<string, byte[]>()));

            // Blocks until execution finishes so callers receive a fully executed batch.
            batch.Execute(CreateTestConnection(new[] {StandardTestData}, false), CancellationToken.None).Wait();
            return batch;
        }

        /// <summary>
        /// Creates a <see cref="Query"/> for <see cref="StandardQuery"/> on a mock connection
        /// that returns a single <see cref="StandardTestData"/> result set, and runs it to completion.
        /// </summary>
        /// <returns>The query, after its <c>ExecutionTask</c> has completed.</returns>
        public static Query GetBasicExecutedQuery()
        {
            ConnectionInfo ci = CreateTestConnectionInfo(new[] {StandardTestData}, false);
            Query query = new Query(StandardQuery, ci, new QueryExecutionSettings(),
                GetFileStreamFactory(new Dictionary<string, byte[]>()));
            query.Execute();
            query.ExecutionTask.Wait();
            return query;
        }

        /// <summary>
        /// Generates a synthetic result set.
        /// </summary>
        /// <param name="columns">Number of entries per row; keys are "column0", "column1", ...</param>
        /// <param name="rows">Number of rows to generate.</param>
        /// <returns>
        /// An array of <paramref name="rows"/> dictionaries where the entry for column c in row r
        /// is "column{c}" -> "val{c}{r}".
        /// </returns>
        public static Dictionary<string, string>[] GetTestData(int columns, int rows)
        {
            Dictionary<string, string>[] output = new Dictionary<string, string>[rows];
            for (int row = 0; row < rows; row++)
            {
                Dictionary<string, string> rowDictionary = new Dictionary<string, string>();
                for (int column = 0; column < columns; column++)
                {
                    rowDictionary.Add(string.Format("column{0}", column), string.Format("val{0}{1}", column, row));
                }
                output[row] = rowDictionary;
            }
            return output;
        }

        /// <summary>
        /// Builds <paramref name="dataSets"/> independent copies of <see cref="StandardTestData"/>.
        /// </summary>
        /// <param name="dataSets">How many result sets to produce; zero or negative yields an empty array.</param>
        /// <returns>One generated result set per requested data set.</returns>
        public static Dictionary<string, string>[][] GetTestDataSet(int dataSets)
        {
            List<Dictionary<string, string>[]> output = new List<Dictionary<string, string>[]>();
            for (int dataSet = 0; dataSet < dataSets; dataSet++)
            {
                output.Add(StandardTestData);
            }
            return output.ToArray();
        }

        /// <summary>
        /// Sends an execute request to <paramref name="service"/> and, if the service registered
        /// an active query for the owner URI with a non-null execution task, waits for that task.
        /// </summary>
        /// <param name="service">The service under test.</param>
        /// <param name="qeParams">The execute request; its <c>OwnerUri</c> keys the active query.</param>
        /// <param name="requestContext">The (typically mocked) request context passed to the handler.</param>
        public static async Task AwaitExecution(QueryExecutionService service, QueryExecuteParams qeParams,
            RequestContext<QueryExecuteResult> requestContext)
        {
            await service.HandleExecuteRequest(qeParams, requestContext);

            // A single TryGetValue instead of ContainsKey + indexer: one lookup, and no
            // KeyNotFoundException if the entry disappears between the two calls.
            Query activeQuery;
            if (service.ActiveQueries.TryGetValue(qeParams.OwnerUri, out activeQuery)
                && activeQuery.ExecutionTask != null)
            {
                await activeQuery.ExecutionTask;
            }
        }

        #endregion

        #region FileStreamWriteMocking

        /// <summary>
        /// Creates a mock <see cref="IFileStreamFactory"/> whose "files" are in-memory byte arrays
        /// kept in <paramref name="storage"/>.
        /// </summary>
        /// <param name="storage">
        /// Backing store, keyed by the GUID file name returned from <c>CreateFile</c>. Each created
        /// file gets a fixed 8192-byte buffer. The buffer is wrapped in a non-resizable
        /// <see cref="MemoryStream"/>, so writes past 8192 bytes fail.
        /// </param>
        /// <returns>The mocked factory.</returns>
        public static IFileStreamFactory GetFileStreamFactory(Dictionary<string, byte[]> storage)
        {
            Mock<IFileStreamFactory> mock = new Mock<IFileStreamFactory>();
            mock.Setup(fsf => fsf.CreateFile())
                .Returns(() =>
                {
                    string fileName = Guid.NewGuid().ToString();
                    storage.Add(fileName, new byte[8192]);
                    return fileName;
                });
            mock.Setup(fsf => fsf.GetReader(It.IsAny<string>()))
                .Returns<string>(output => new ServiceBufferFileStreamReader(new MemoryStream(storage[output])));
            // NOTE(review): the meaning of the two 1024 arguments is defined by ServiceBufferFileStreamWriter
            // (not visible here) - presumably buffer/size limits; confirm before changing.
            mock.Setup(fsf => fsf.GetWriter(It.IsAny<string>()))
                .Returns<string>(output => new ServiceBufferFileStreamWriter(new MemoryStream(storage[output]), 1024, 1024));
            return mock.Object;
        }

        #endregion

        #region DbConnection Mocking

        /// <summary>
        /// Creates a mock <see cref="DbCommand"/> whose reader either returns
        /// <paramref name="data"/> or throws a mocked <see cref="DbException"/>.
        /// </summary>
        /// <param name="data">Result sets served by the returned <c>TestDbDataReader</c>.</param>
        /// <param name="throwOnRead">
        /// When true, executing a reader throws a <see cref="DbException"/> whose Message is "Message".
        /// </param>
        /// <returns>The mocked command.</returns>
        public static DbCommand CreateTestCommand(Dictionary<string, string>[][] data, bool throwOnRead)
        {
            var commandMock = new Mock<DbCommand> { CallBase = true };

            // Argument matchers in Protected() setups must be ItExpr, not It, for the setup to
            // match any CommandBehavior rather than only a single fixed value.
            var commandMockSetup = commandMock.Protected()
                .Setup<DbDataReader>("ExecuteDbDataReader", ItExpr.IsAny<CommandBehavior>());

            // Setup the expected behavior
            if (throwOnRead)
            {
                var mockException = new Mock<DbException>();
                mockException.SetupGet(dbe => dbe.Message).Returns("Message");
                commandMockSetup.Throws(mockException.Object);
            }
            else
            {
                commandMockSetup.Returns(new TestDbDataReader(data));
            }

            return commandMock.Object;
        }

        /// <summary>
        /// Creates a mock <see cref="DbConnection"/> that hands out a new
        /// <see cref="CreateTestCommand"/> command for every <c>CreateCommand</c> call, and whose
        /// <c>State</c> follows <c>Open()</c> / <c>Close()</c>.
        /// </summary>
        /// <param name="data">Result sets each created command will serve.</param>
        /// <param name="throwOnRead">Whether created commands throw when a reader is executed.</param>
        /// <returns>The mocked connection.</returns>
        public static DbConnection CreateTestConnection(Dictionary<string, string>[][] data, bool throwOnRead)
        {
            var connectionMock = new Mock<DbConnection> { CallBase = true };
            connectionMock.Protected()
                .Setup<DbCommand>("CreateDbCommand")
                .Returns(() => CreateTestCommand(data, throwOnRead));
            connectionMock.Setup(dbc => dbc.Open())
                .Callback(() => connectionMock.SetupGet(dbc => dbc.State).Returns(ConnectionState.Open));
            connectionMock.Setup(dbc => dbc.Close())
                .Callback(() => connectionMock.SetupGet(dbc => dbc.State).Returns(ConnectionState.Closed));
            return connectionMock.Object;
        }

        /// <summary>
        /// Creates a mock <see cref="ISqlConnectionFactory"/> that returns a new
        /// <see cref="CreateTestConnection"/> connection for any connection string.
        /// </summary>
        /// <param name="data">Result sets served by the connections' commands.</param>
        /// <param name="throwOnRead">Whether the connections' commands throw on read.</param>
        /// <returns>The mocked factory.</returns>
        public static ISqlConnectionFactory CreateMockFactory(Dictionary<string, string>[][] data, bool throwOnRead)
        {
            var mockFactory = new Mock<ISqlConnectionFactory>();
            mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>()))
                .Returns(() => CreateTestConnection(data, throwOnRead));
            return mockFactory.Object;
        }

        /// <summary>
        /// Creates a <see cref="ConnectionInfo"/> for <see cref="OwnerUri"/> and
        /// <see cref="StandardConnectionDetails"/>, backed by <see cref="CreateMockFactory"/>.
        /// </summary>
        /// <param name="data">Result sets served by the connection's commands.</param>
        /// <param name="throwOnRead">Whether the connection's commands throw on read.</param>
        /// <returns>The connection info.</returns>
        public static ConnectionInfo CreateTestConnectionInfo(Dictionary<string, string>[][] data, bool throwOnRead)
        {
            return new ConnectionInfo(CreateMockFactory(data, throwOnRead), OwnerUri, StandardConnectionDetails);
        }

        #endregion

        #region Service Mocking

        /// <summary>
        /// Creates a <see cref="QueryExecutionService"/> over a mocked <see cref="ConnectionService"/>
        /// and an in-memory file-stream factory.
        /// </summary>
        /// <param name="data">Result sets the mock connection returns.</param>
        /// <param name="isConnected">
        /// Value returned by <c>TryFindConnection</c>. When true its out value is a connection info
        /// for <see cref="OwnerUri"/>; when false it is null.
        /// </param>
        /// <param name="throwOnRead">Whether executing a reader throws.</param>
        /// <param name="workspaceService">Workspace service handed to the execution service.</param>
        /// <param name="storage">Receives the in-memory backing store for buffered result "files".</param>
        /// <returns>The primed execution service.</returns>
        public static QueryExecutionService GetPrimedExecutionService(Dictionary<string, string>[][] data,
            bool isConnected, bool throwOnRead, WorkspaceService<SqlToolsSettings> workspaceService,
            out Dictionary<string, byte[]> storage)
        {
            // Create a place for the temp "files" to be written
            storage = new Dictionary<string, byte[]>();

            // Create the connection factory with the dataset (no throwaway ConnectionInfo needed)
            ISqlConnectionFactory factory = CreateMockFactory(data, throwOnRead);

            // Mock the connection service
            var connectionService = new Mock<ConnectionService>();
            ConnectionInfo ci = new ConnectionInfo(factory, OwnerUri, StandardConnectionDetails);
            ConnectionInfo outValMock;
            connectionService
                .Setup(service => service.TryFindConnection(It.IsAny<string>(), out outValMock))
                .OutCallback((string owner, out ConnectionInfo connInfo) => connInfo = isConnected ? ci : null)
                .Returns(isConnected);

            return new QueryExecutionService(connectionService.Object, workspaceService)
            {
                BufferFileStreamFactory = GetFileStreamFactory(storage)
            };
        }

        /// <summary>
        /// Same as the overload with an <c>out</c> storage parameter, discarding the backing store.
        /// </summary>
        public static QueryExecutionService GetPrimedExecutionService(Dictionary<string, string>[][] data,
            bool isConnected, bool throwOnRead, WorkspaceService<SqlToolsSettings> workspaceService)
        {
            Dictionary<string, byte[]> storage;
            return GetPrimedExecutionService(data, isConnected, throwOnRead, workspaceService, out storage);
        }

        /// <summary>
        /// Creates a mock workspace service whose <c>Workspace.GetFile</c> returns, for any path,
        /// a mock script file whose <c>Contents</c> is <paramref name="query"/>.
        /// </summary>
        /// <param name="query">Text returned as the file contents.</param>
        /// <returns>The mocked workspace service.</returns>
        public static WorkspaceService<SqlToolsSettings> GetPrimedWorkspaceService(string query)
        {
            // Set up file for returning the query
            var fileMock = new Mock<ScriptFile>();
            fileMock.SetupGet(file => file.Contents).Returns(query);

            // Set up workspace mock
            var workspaceService = new Mock<WorkspaceService<SqlToolsSettings>>();
            workspaceService.Setup(service => service.Workspace.GetFile(It.IsAny<string>()))
                .Returns(fileMock.Object);

            return workspaceService.Object;
        }

        #endregion
    }
}