// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//

using System;
using System.IO;
using System.Threading.Tasks;
using System.Runtime.InteropServices;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Moq;
using Xunit;

namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution
{
    /// <summary>
    /// Tests for saving a result set to a file
    /// </summary>
    public class SaveResultsTests
    {
        /// <summary>
        /// Test save results to a file as CSV with correct parameters
        /// </summary>
        [Fact]
        public void SaveResultsAsCsvSuccessTest()
        {
            // Execute a query
            var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
            var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri };
            var executeRequest = GetQueryExecuteResultContextMock(null, null, null);
            queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();

            // Request to save the results as csv with correct parameters.
            // A per-test unique path keeps this test independent of the working directory
            // and of files left behind by other tests.
            var saveParams = new SaveResultsRequestParams
            {
                OwnerUri = Common.OwnerUri,
                ResultSetIndex = 0,
                BatchIndex = 0,
                FilePath = GetUniqueTempCsvPath(),
                IncludeHeaders = true
            };
            SaveResultRequestResult result = null;
            var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null);
            queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch();

            try
            {
                queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait();

                // Expect to see a file successfully created in filepath and a success message
                Assert.NotNull(result);
                Assert.Equal("Success", result.Messages);
                Assert.True(File.Exists(saveParams.FilePath));
                VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never());
            }
            finally
            {
                // Remove the output file even when an assertion above fails
                DeleteFileIfExists(saveParams.FilePath);
            }
        }

        /// <summary>
        /// Test handling exception in saving results to file
        /// </summary>
        [Fact]
        public void SaveResultsAsCsvExceptionTest()
        {
            // Execute a query
            var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
            var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri };
            var executeRequest = GetQueryExecuteResultContextMock(null, null, null);
            queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();

            // Request to save the results as csv with an incorrect filepath: a file inside a
            // freshly generated directory name that does not exist. Unlike a fixed drive letter
            // (which may be mapped) or the filesystem root (writable when running as root), this
            // path cannot be written on any platform unless the directory is created first.
            string missingDirectory = Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N"));
            var saveParams = new SaveResultsRequestParams
            {
                OwnerUri = Common.OwnerUri,
                ResultSetIndex = 0,
                BatchIndex = 0,
                FilePath = Path.Combine(missingDirectory, "test.csv")
            };
            string errMessage = null;
            var saveRequest = GetSaveResultsContextMock(null, err => errMessage = (string)err);
            queryService.ActiveQueries[Common.OwnerUri].Batches[0] = Common.GetBasicExecutedBatch();

            try
            {
                queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait();

                // Expect to see error message
                Assert.NotNull(errMessage);
                VerifySaveResultsCallCount(saveRequest, Times.Never(), Times.Once());
                Assert.False(File.Exists(saveParams.FilePath));
            }
            finally
            {
                // NOTE(review): assumes the save handler does not create missing parent
                // directories; clean up defensively in case that ever changes.
                DeleteFileIfExists(saveParams.FilePath);
                if (Directory.Exists(missingDirectory))
                {
                    Directory.Delete(missingDirectory, true);
                }
            }
        }

        /// <summary>
        /// Test saving results to file when the requested result set is no longer active
        /// </summary>
        [Fact]
        public void SaveResultsAsCsvQueryNotFoundTest()
        {
            // Execute a query
            var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true);
            var executeParams = new QueryExecuteParams { QueryText = Common.StandardQuery, OwnerUri = Common.OwnerUri };
            var executeRequest = GetQueryExecuteResultContextMock(null, null, null);
            queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait();

            // Request to save the results as csv with query that is no longer active.
            // A unique path makes the "file was not created" assertion meaningful regardless
            // of what other tests may have written.
            var saveParams = new SaveResultsRequestParams
            {
                OwnerUri = "falseuri",
                ResultSetIndex = 0,
                BatchIndex = 0,
                FilePath = GetUniqueTempCsvPath()
            };
            SaveResultRequestResult result = null;
            var saveRequest = GetSaveResultsContextMock(qcr => result = qcr, null);

            try
            {
                queryService.HandleSaveResultsAsCsvRequest(saveParams, saveRequest.Object).Wait();

                // Expect message that save failed
                Assert.NotNull(result);
                Assert.Equal("Failed to save results, ID not found.", result.Messages);
                Assert.False(File.Exists(saveParams.FilePath));
                VerifySaveResultsCallCount(saveRequest, Times.Once(), Times.Never());
            }
            finally
            {
                DeleteFileIfExists(saveParams.FilePath);
            }
        }

        #region Helpers

        /// <summary>
        /// Builds a path to a .csv file in the system temp directory with a GUID-based name,
        /// so each call yields a path no other test is using.
        /// </summary>
        /// <returns>Absolute path of a (not yet created) .csv file.</returns>
        private static string GetUniqueTempCsvPath()
        {
            return Path.Combine(Path.GetTempPath(), Guid.NewGuid().ToString("N") + ".csv");
        }

        /// <summary>
        /// Deletes <paramref name="path"/> if it exists. The existence check also avoids the
        /// exception <see cref="File.Delete"/> throws when the parent directory is missing.
        /// </summary>
        /// <param name="path">File to remove.</param>
        private static void DeleteFileIfExists(string path)
        {
            if (File.Exists(path))
            {
                File.Delete(path);
            }
        }

        #endregion

        #region Mocking

        /// <summary>
        /// Mock the requestContext for saving a result set
        /// </summary>
        /// <param name="resultCallback">Invoked with the result passed to SendResult; may be null</param>
        /// <param name="errorCallback">Invoked with the error object passed to SendError; may be null</param>
        /// <returns>A mock whose SendResult and SendError return completed tasks</returns>
        private static Mock<RequestContext<SaveResultRequestResult>> GetSaveResultsContextMock(
            Action<SaveResultRequestResult> resultCallback,
            Action<object> errorCallback)
        {
            var requestContext = new Mock<RequestContext<SaveResultRequestResult>>();

            // Setup the mock for SendResult
            var sendResultFlow = requestContext
                .Setup(rc => rc.SendResult(It.IsAny<SaveResultRequestResult>()))
                .Returns(Task.FromResult(0));
            if (resultCallback != null)
            {
                sendResultFlow.Callback(resultCallback);
            }

            // Setup the mock for SendError
            var sendErrorFlow = requestContext
                .Setup(rc => rc.SendError(It.IsAny<object>()))
                .Returns(Task.FromResult(0));
            if (errorCallback != null)
            {
                sendErrorFlow.Callback(errorCallback);
            }

            return requestContext;
        }

        /// <summary>
        /// Verify the call count for sendResult and error
        /// </summary>
        /// <param name="mock">The request context mock to verify</param>
        /// <param name="sendResultCalls">Expected number of SendResult calls</param>
        /// <param name="sendErrorCalls">Expected number of SendError calls</param>
        private static void VerifySaveResultsCallCount(Mock<RequestContext<SaveResultRequestResult>> mock,
            Times sendResultCalls, Times sendErrorCalls)
        {
            mock.Verify(rc => rc.SendResult(It.IsAny<SaveResultRequestResult>()), sendResultCalls);
            mock.Verify(rc => rc.SendError(It.IsAny<object>()), sendErrorCalls);
        }

        /// <summary>
        /// Mock request context for executing a query
        /// </summary>
        /// <param name="resultCallback">Invoked with the result passed to SendResult; may be null</param>
        /// <param name="eventCallback">Invoked when the query-execute-complete event is sent; may be null</param>
        /// <param name="errorCallback">Invoked with the error object passed to SendError; may be null</param>
        /// <returns>A mock whose SendResult, SendEvent and SendError return completed tasks</returns>
        public static Mock<RequestContext<QueryExecuteResult>> GetQueryExecuteResultContextMock(
            Action<QueryExecuteResult> resultCallback,
            Action<EventType<QueryExecuteCompleteParams>, QueryExecuteCompleteParams> eventCallback,
            Action<object> errorCallback)
        {
            var requestContext = new Mock<RequestContext<QueryExecuteResult>>();

            // Setup the mock for SendResult
            var sendResultFlow = requestContext
                .Setup(rc => rc.SendResult(It.IsAny<QueryExecuteResult>()))
                .Returns(Task.FromResult(0));
            if (resultCallback != null)
            {
                sendResultFlow.Callback(resultCallback);
            }

            // Setup the mock for SendEvent; only the query-execute-complete event is matched
            var sendEventFlow = requestContext.Setup(rc => rc.SendEvent(
                    It.Is<EventType<QueryExecuteCompleteParams>>(m => m == QueryExecuteCompleteEvent.Type),
                    It.IsAny<QueryExecuteCompleteParams>()))
                .Returns(Task.FromResult(0));
            if (eventCallback != null)
            {
                sendEventFlow.Callback(eventCallback);
            }

            // Setup the mock for SendError
            var sendErrorFlow = requestContext.Setup(rc => rc.SendError(It.IsAny<object>()))
                .Returns(Task.FromResult(0));
            if (errorCallback != null)
            {
                sendErrorFlow.Callback(errorCallback);
            }

            return requestContext;
        }

        #endregion
    }
}