diff --git a/sqltoolsservice.sln b/sqltoolsservice.sln index 828baca9..cd55b538 100644 --- a/sqltoolsservice.sln +++ b/sqltoolsservice.sln @@ -8,7 +8,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{AB9CA2B8-6 EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{32DC973E-9EEA-4694-B1C2-B031167AB945}" ProjectSection(SolutionItems) = preProject + .gitignore = .gitignore global.json = global.json + nuget.config = nuget.config + README.md = README.md EndProjectSection EndProject Project("{8BB2217D-0F2D-49D1-97BC-3654ED321F3B}") = "Microsoft.SqlTools.ServiceLayer", "src\Microsoft.SqlTools.ServiceLayer\Microsoft.SqlTools.ServiceLayer.xproj", "{0D61DC2B-DA66-441D-B9D0-F76C98F780F9}" diff --git a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs index 153e46d6..a2811f6a 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Hosting/Protocol/RequestContext.cs @@ -20,7 +20,9 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol this.messageWriter = messageWriter; } - public async Task SendResult(TResult resultDetails) + public RequestContext() { } + + public virtual async Task SendResult(TResult resultDetails) { await this.messageWriter.WriteResponse( resultDetails, @@ -28,14 +30,14 @@ namespace Microsoft.SqlTools.ServiceLayer.Hosting.Protocol requestMessage.Id); } - public async Task SendEvent(EventType eventType, TParams eventParams) + public virtual async Task SendEvent(EventType eventType, TParams eventParams) { await this.messageWriter.WriteEvent( eventType, eventParams); } - public async Task SendError(object errorDetails) + public virtual async Task SendError(object errorDetails) { await this.messageWriter.WriteMessage( Message.ResponseError( diff --git a/src/Microsoft.SqlTools.ServiceLayer/Program.cs b/src/Microsoft.SqlTools.ServiceLayer/Program.cs index f6054354..c0f547c2 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Program.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Program.cs @@ -2,12 +2,16 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Threading.Tasks; using Microsoft.SqlTools.EditorServices.Utility; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; namespace Microsoft.SqlTools.ServiceLayer { @@ -46,6 +50,7 @@ namespace Microsoft.SqlTools.ServiceLayer AutoCompleteService.Instance.InitializeService(serviceHost); LanguageService.Instance.InitializeService(serviceHost, sqlToolsContext); ConnectionService.Instance.InitializeService(serviceHost); + QueryExecutionService.Instance.InitializeService(serviceHost); serviceHost.Initialize(); serviceHost.WaitForExit(); diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs new file mode 100644 index 00000000..70e6631c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryDisposeRequest.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query dispose request + /// + public class QueryDisposeParams + { + public string OwnerUri { get; set; } + } + + /// + /// Parameters to return as the result of a query dispose request + /// + public class QueryDisposeResult + { + /// + /// Any error messages that occurred during disposing the result set. Optional, can be set + /// to null if there were no errors. + /// + public string Messages { get; set; } + } + + public class QueryDisposeRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/dispose"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs new file mode 100644 index 00000000..f81edb62 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteCompleteNotification.cs @@ -0,0 +1,42 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters to be sent back with a query execution complete event + /// + public class QueryExecuteCompleteParams + { + /// + /// URI for the editor that owns the query + /// + public string OwnerUri { get; set; } + + /// + /// Any messages that came back from the server during execution of the query + /// + public string[] Messages { get; set; } + + /// + /// Whether or not the query was successful. True indicates errors, false indicates success + /// + public bool HasError { get; set; } + + /// + /// Summaries of the result sets that were returned with the query + /// + public ResultSetSummary[] ResultSetSummaries { get; set; } + } + + public class QueryExecuteCompleteEvent + { + public static readonly + EventType Type = + EventType.Create("query/complete"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs new file mode 100644 index 00000000..cac98c1a --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteRequest.cs @@ -0,0 +1,43 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for the query execute request + /// + public class QueryExecuteParams + { + /// + /// The text of the query to execute + /// + public string QueryText { get; set; } + + /// + /// URI for the editor that is asking for the query execute + /// + public string OwnerUri { get; set; } + } + + /// + /// Parameters for the query execute result + /// + public class QueryExecuteResult + { + /// + /// Connection error messages. Optional, can be set to null to indicate no errors + /// + public string Messages { get; set; } + } + + public class QueryExecuteRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/execute"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteSubsetRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteSubsetRequest.cs new file mode 100644 index 00000000..cdf434bb --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/QueryExecuteSubsetRequest.cs @@ -0,0 +1,61 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Parameters for a query result subset retrieval request + /// + public class QueryExecuteSubsetParams + { + /// + /// URI for the file that owns the query to look up the results for + /// + public string OwnerUri { get; set; } + + /// + /// Index of the result set to get the results from + /// + public int ResultSetIndex { get; set; } + + /// + /// Beginning index of the rows to return from the selected resultset. This index will be + /// included in the results. + /// + public int RowsStartIndex { get; set; } + + /// + /// Number of rows to include in the result of this request. If the number of the rows + /// exceeds the number of rows available after the start index, all available rows after + /// the start index will be returned. + /// + public int RowsCount { get; set; } + } + + /// + /// Parameters for the result of a subset retrieval request + /// + public class QueryExecuteSubsetResult + { + /// + /// Subset request error messages. Optional, can be set to null to indicate no errors + /// + public string Message { get; set; } + + /// + /// The requested subset of results. Optional, can be set to null to indicate an error + /// + public ResultSetSubset ResultSubset { get; set; } + } + + public class QueryExecuteSubsetRequest + { + public static readonly + RequestType Type = + RequestType.Create("query/subset"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs new file mode 100644 index 00000000..8e2b49a9 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSubset.cs @@ -0,0 +1,24 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Class used to represent a subset of results from a query for transmission across JSON RPC + /// + public class ResultSetSubset + { + /// + /// The number of rows returned from result set, useful for determining if less rows were + /// returned than requested. + /// + public int RowCount { get; set; } + + /// + /// 2D array of the cell values requested from result set + /// + public object[][] Rows { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs new file mode 100644 index 00000000..5f8de12a --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Contracts/ResultSetSummary.cs @@ -0,0 +1,30 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Data.Common; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts +{ + /// + /// Represents a summary of information about a result without returning any cells of the results + /// + public class ResultSetSummary + { + /// + /// The ID of the result set within the query results + /// + public int Id { get; set; } + + /// + /// The number of rows that was returned with the resultset + /// + public int RowCount { get; set; } + + /// + /// Details about the columns that are provided as solutions + /// + public DbColumn[] ColumnInfo { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs new file mode 100644 index 00000000..434188a5 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/Query.cs @@ -0,0 +1,234 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + /// + /// Internal representation of an active query + /// + public class Query : IDisposable + { + #region Properties + + /// + /// Cancellation token source, used for cancelling async db actions + /// + private readonly CancellationTokenSource cancellationSource; + + /// + /// The connection info associated with the file editor owner URI, used to create a new + /// connection upon execution of the query + /// + public ConnectionInfo EditorConnection { get; set; } + + public bool HasExecuted { get; set; } + + /// + /// The text of the query to execute + /// + public string QueryText { get; set; } + + /// + /// The result sets of the query execution + /// + public List ResultSets { get; set; } + + /// + /// Property for generating a set result set summaries from the result sets + /// + public ResultSetSummary[] ResultSummary + { + get + { + return ResultSets.Select((set, index) => new ResultSetSummary + { + ColumnInfo = set.Columns, + Id = index, + RowCount = set.Rows.Count + }).ToArray(); + } + } + + #endregion + + /// + /// Constructor for a query + /// + /// The text of the query to execute + /// The information of the connection to use to execute the query + public Query(string queryText, ConnectionInfo connection) + { + // Sanity check for input + if (String.IsNullOrWhiteSpace(queryText)) + { + throw new ArgumentNullException(nameof(queryText), "Query text cannot be null"); + } + if (connection == null) + { + throw new ArgumentNullException(nameof(connection), "Connection cannot be null"); + } + + // Initialize the internal state + QueryText = queryText; + EditorConnection = connection; + HasExecuted = false; + ResultSets = new List(); + cancellationSource = new CancellationTokenSource(); + } + + /// + /// Executes this query asynchronously and collects all result sets + /// + public async Task Execute() + { + // Sanity check to make sure we haven't already run this query + if (HasExecuted) + { + throw new InvalidOperationException("Query has already executed."); + } + + DbConnection conn = null; + + // Create a connection from the connection details + try + { + string connectionString = ConnectionService.BuildConnectionString(EditorConnection.ConnectionDetails); + using (EditorConnection.Factory.CreateSqlConnection(connectionString)) + { + await conn.OpenAsync(cancellationSource.Token); + + // Create a command that we'll use for executing the query + using (DbCommand command = conn.CreateCommand()) + { + command.CommandText = QueryText; + command.CommandType = CommandType.Text; + + // Execute the command to get back a reader + using (DbDataReader reader = await command.ExecuteReaderAsync(cancellationSource.Token)) + { + do + { + // TODO: This doesn't properly handle scenarios where the query is SELECT but does not have rows + if (!reader.HasRows) + { + continue; + } + + // Read until we hit the end of the result set + ResultSet resultSet = new ResultSet(); + while (await reader.ReadAsync(cancellationSource.Token)) + { + resultSet.AddRow(reader); + } + + // Read off the column schema information + if (reader.CanGetColumnSchema()) + { + resultSet.Columns = reader.GetColumnSchema().ToArray(); + } + + // Add the result set to the results of the query + ResultSets.Add(resultSet); + } while (await reader.NextResultAsync(cancellationSource.Token)); + } + } + } + } + catch (Exception) + { + // Dispose of the connection + conn?.Dispose(); + } + finally + { + // Mark that we have executed + HasExecuted = true; + } + } + + /// + /// Retrieves a subset of the result sets + /// + /// The index for selecting the result set + /// The starting row of the results + /// How many rows to retrieve + /// A subset of results + public ResultSetSubset GetSubset(int resultSetIndex, int startRow, int rowCount) + { + // Sanity check that the results are available + if (!HasExecuted) + { + throw new InvalidOperationException("The query has not completed, yet."); + } + + // Sanity check to make sure we have valid numbers + if (resultSetIndex < 0 || resultSetIndex >= ResultSets.Count) + { + throw new ArgumentOutOfRangeException(nameof(resultSetIndex), "Result set index cannot be less than 0" + + "or greater than the number of result sets"); + } + ResultSet targetResultSet = ResultSets[resultSetIndex]; + if (startRow < 0 || startRow >= targetResultSet.Rows.Count) + { + throw new ArgumentOutOfRangeException(nameof(startRow), "Start row cannot be less than 0 " + + "or greater than the number of rows in the resultset"); + } + if (rowCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(rowCount), "Row count must be a positive integer"); + } + + // Retrieve the subset of the results as per the request + object[][] rows = targetResultSet.Rows.Skip(startRow).Take(rowCount).ToArray(); + return new ResultSetSubset + { + Rows = rows, + RowCount = rows.Length + }; + } + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + cancellationSource.Dispose(); + } + + disposed = true; + } + + ~Query() + { + Dispose(false); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs new file mode 100644 index 00000000..6480e4ba --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs @@ -0,0 +1,309 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Concurrent; +using System.Data.Common; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + /// + /// Service for executing queries + /// + public sealed class QueryExecutionService : IDisposable + { + #region Singleton Instance Implementation + + private static readonly Lazy instance = new Lazy(() => new QueryExecutionService()); + + public static QueryExecutionService Instance + { + get { return instance.Value; } + } + + private QueryExecutionService() + { + ConnectionService = ConnectionService.Instance; + } + + internal QueryExecutionService(ConnectionService connService) + { + ConnectionService = connService; + } + + #endregion + + #region Properties + + /// + /// The collection of active queries + /// + internal ConcurrentDictionary ActiveQueries + { + get { return queries.Value; } + } + + /// + /// Instance of the connection service, used to get the connection info for a given owner URI + /// + private ConnectionService ConnectionService { get; set; } + + /// + /// Internal storage of active queries, lazily constructed as a threadsafe dictionary + /// + private readonly Lazy> queries = + new Lazy>(() => new ConcurrentDictionary()); + + #endregion + + /// + /// Initializes the service with the service host, registers request handlers and shutdown + /// event handler. + /// + /// The service host instance to register with + public void InitializeService(ServiceHost serviceHost) + { + // Register handlers for requests + serviceHost.SetRequestHandler(QueryExecuteRequest.Type, HandleExecuteRequest); + serviceHost.SetRequestHandler(QueryExecuteSubsetRequest.Type, HandleResultSubsetRequest); + serviceHost.SetRequestHandler(QueryDisposeRequest.Type, HandleDisposeRequest); + + // Register handler for shutdown event + serviceHost.RegisterShutdownTask((shutdownParams, requestContext) => + { + Dispose(); + return Task.FromResult(0); + }); + } + + #region Request Handlers + + public async Task HandleExecuteRequest(QueryExecuteParams executeParams, + RequestContext requestContext) + { + try + { + // Get a query new active query + Query newQuery = await CreateAndActivateNewQuery(executeParams, requestContext); + + // Execute the query + await ExecuteAndCompleteQuery(executeParams, requestContext, newQuery); + } + catch (Exception e) + { + // Dump any unexpected exceptions as errors + await requestContext.SendError(e.Message); + } + } + + public async Task HandleResultSubsetRequest(QueryExecuteSubsetParams subsetParams, + RequestContext requestContext) + { + try + { + // Attempt to load the query + Query query; + if (!ActiveQueries.TryGetValue(subsetParams.OwnerUri, out query)) + { + await requestContext.SendResult(new QueryExecuteSubsetResult + { + Message = "The requested query does not exist." + }); + return; + } + + // Retrieve the requested subset and return it + var result = new QueryExecuteSubsetResult + { + Message = null, + ResultSubset = query.GetSubset( + subsetParams.ResultSetIndex, subsetParams.RowsStartIndex, subsetParams.RowsCount) + }; + await requestContext.SendResult(result); + } + catch (InvalidOperationException ioe) + { + // Return the error as a result + await requestContext.SendResult(new QueryExecuteSubsetResult + { + Message = ioe.Message + }); + } + catch (ArgumentOutOfRangeException aoore) + { + // Return the error as a result + await requestContext.SendResult(new QueryExecuteSubsetResult + { + Message = aoore.Message + }); + } + catch (Exception e) + { + // This was unexpected, so send back as error + await requestContext.SendError(e.Message); + } + } + + public async Task HandleDisposeRequest(QueryDisposeParams disposeParams, + RequestContext requestContext) + { + try + { + // Attempt to remove the query for the owner uri + Query result; + if (!ActiveQueries.TryRemove(disposeParams.OwnerUri, out result)) + { + await requestContext.SendResult(new QueryDisposeResult + { + Messages = "Failed to dispose query, ID not found." + }); + return; + } + + // Success + await requestContext.SendResult(new QueryDisposeResult + { + Messages = null + }); + } + catch (Exception e) + { + await requestContext.SendError(e.Message); + } + } + + #endregion + + #region Private Helpers + + private async Task CreateAndActivateNewQuery(QueryExecuteParams executeParams, RequestContext requestContext) + { + try + { + // Attempt to get the connection for the editor + ConnectionInfo connectionInfo; + if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connectionInfo)) + { + await requestContext.SendResult(new QueryExecuteResult + { + Messages = "This editor is not connected to a database." + }); + return null; + } + + // Attempt to clean out any old query on the owner URI + Query oldQuery; + if (ActiveQueries.TryGetValue(executeParams.OwnerUri, out oldQuery) && oldQuery.HasExecuted) + { + ActiveQueries.TryRemove(executeParams.OwnerUri, out oldQuery); + } + + // If we can't add the query now, it's assumed the query is in progress + Query newQuery = new Query(executeParams.QueryText, connectionInfo); + if (!ActiveQueries.TryAdd(executeParams.OwnerUri, newQuery)) + { + await requestContext.SendResult(new QueryExecuteResult + { + Messages = "A query is already in progress for this editor session." + + "Please cancel this query or wait for its completion." + }); + return null; + } + + return newQuery; + } + catch (ArgumentNullException ane) + { + await requestContext.SendResult(new QueryExecuteResult { Messages = ane.Message }); + return null; + } + // Any other exceptions will fall through here and be collected at the end + } + + private async Task ExecuteAndCompleteQuery(QueryExecuteParams executeParams, RequestContext requestContext, Query query) + { + // Skip processing if the query is null + if (query == null) + { + return; + } + + // Launch the query and respond with successfully launching it + Task executeTask = query.Execute(); + await requestContext.SendResult(new QueryExecuteResult + { + Messages = null + }); + + try + { + // Wait for query execution and then send back the results + await Task.WhenAll(executeTask); + QueryExecuteCompleteParams eventParams = new QueryExecuteCompleteParams + { + HasError = false, + Messages = new string[] { }, // TODO: Figure out how to get messages back from the server + OwnerUri = executeParams.OwnerUri, + ResultSetSummaries = query.ResultSummary + }; + await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, eventParams); + } + catch (DbException dbe) + { + // Dump the message to a complete event + QueryExecuteCompleteParams errorEvent = new QueryExecuteCompleteParams + { + HasError = true, + Messages = new[] {dbe.Message}, + OwnerUri = executeParams.OwnerUri, + ResultSetSummaries = query.ResultSummary + }; + await requestContext.SendEvent(QueryExecuteCompleteEvent.Type, errorEvent); + } + } + + #endregion + + #region IDisposable Implementation + + private bool disposed; + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (disposed) + { + return; + } + + if (disposing) + { + foreach (var query in ActiveQueries) + { + query.Value.Dispose(); + } + } + + disposed = true; + } + + ~QueryExecutionService() + { + Dispose(false); + } + + #endregion + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs new file mode 100644 index 00000000..fed08ea3 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/ResultSet.cs @@ -0,0 +1,37 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.Data.Common; + +namespace Microsoft.SqlTools.ServiceLayer.QueryExecution +{ + public class ResultSet + { + public DbColumn[] Columns { get; set; } + + public List Rows { get; private set; } + + public ResultSet() + { + Rows = new List(); + } + + /// + /// Add a row of data to the result set using a that has already + /// read in a row. + /// + /// A that has already had a read performed + public void AddRow(DbDataReader reader) + { + List row = new List(); + for (int i = 0; i < reader.FieldCount; ++i) + { + row.Add(reader.GetValue(i)); + } + Rows.Add(row.ToArray()); + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs index 560805d7..3099a3d5 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Workspace/Workspace.cs @@ -124,6 +124,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Workspace // type SqlTools have a path starting with 'untitled'. return filePath.StartsWith("inmemory") || + filePath.StartsWith("tsqloutput") || filePath.StartsWith("untitled"); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs new file mode 100644 index 00000000..9bc8053b --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/Common.cs @@ -0,0 +1,181 @@ +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; +using Moq; +using Moq.Protected; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class Common + { + public const string OwnerUri = "testFile"; + + public static readonly Dictionary[] StandardTestData = + { + new Dictionary { {"col1", "val11"}, { "col2", "val12"}, { "col3", "val13"}, { "col4", "col14"} }, + new Dictionary { {"col1", "val21"}, { "col2", "val22"}, { "col3", "val23"}, { "col4", "col24"} }, + new Dictionary { {"col1", "val31"}, { "col2", "val32"}, { "col3", "val33"}, { "col4", "col34"} }, + new Dictionary { {"col1", "val41"}, { "col2", "val42"}, { "col3", "val43"}, { "col4", "col44"} }, + new Dictionary { {"col1", "val51"}, { "col2", "val52"}, { "col3", "val53"}, { "col4", "col54"} }, + }; + + public static Dictionary[] GetTestData(int columns, int rows) + { + Dictionary[] output = new Dictionary[rows]; + for (int row = 0; row < rows; row++) + { + Dictionary rowDictionary = new Dictionary(); + for (int column = 0; column < columns; column++) + { + rowDictionary.Add(String.Format("column{0}", column), String.Format("val{0}{1}", column, row)); + } + output[row] = rowDictionary; + } + + return output; + } + + public static Query GetBasicExecutedQuery() + { + Query query = new Query("SIMPLE QUERY", CreateTestConnectionInfo(new[] { StandardTestData }, false)); + query.Execute().Wait(); + return query; + } + + #region DbConnection Mocking + + public static DbCommand CreateTestCommand(Dictionary[][] data, bool throwOnRead) + { + var commandMock = new Mock { CallBase = true }; + var commandMockSetup = commandMock.Protected() + .Setup("ExecuteDbDataReader", It.IsAny()); + + // Setup the expected behavior + if (throwOnRead) + { + commandMockSetup.Throws(new Mock().Object); + } + else + { + commandMockSetup.Returns(new TestDbDataReader(data)); + } + + + return commandMock.Object; + } + + public static DbConnection CreateTestConnection(Dictionary[][] data, bool throwOnRead) + { + var connectionMock = new Mock { CallBase = true }; + connectionMock.Protected() + .Setup("CreateDbCommand") + .Returns(CreateTestCommand(data, throwOnRead)); + + return connectionMock.Object; + } + + public static ISqlConnectionFactory CreateMockFactory(Dictionary[][] data, bool throwOnRead) + { + var mockFactory = new Mock(); + mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny())) + .Returns(CreateTestConnection(data, throwOnRead)); + + return mockFactory.Object; + } + + public static ConnectionInfo CreateTestConnectionInfo(Dictionary[][] data, bool throwOnRead) + { + // Create connection info + ConnectionDetails connDetails = new ConnectionDetails + { + UserName = "sa", + Password = "Yukon900", + DatabaseName = "AdventureWorks2016CTP3_2", + ServerName = "sqltools11" + }; + + return new ConnectionInfo(CreateMockFactory(data, throwOnRead), "test://test", connDetails); + } + + #endregion + + #region Service Mocking + + public static ConnectionDetails GetTestConnectionDetails() + { + return new ConnectionDetails + { + DatabaseName = "123", + Password = "456", + ServerName = "789", + UserName = "012" + }; + } + + public static QueryExecutionService GetPrimedExecutionService(ISqlConnectionFactory factory, bool isConnected) + { + var connectionService = new ConnectionService(factory); + if (isConnected) + { + connectionService.Connect(new ConnectParams + { + Connection = GetTestConnectionDetails(), + OwnerUri = OwnerUri + }); + } + return new QueryExecutionService(connectionService); + } + + #endregion + + #region Request Mocking + + public static Mock> GetQueryExecuteResultContextMock( + Action resultCallback, + Action, QueryExecuteCompleteParams> eventCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendEvent + var sendEventFlow = requestContext.Setup(rc => rc.SendEvent( + It.Is>(m => m == QueryExecuteCompleteEvent.Type), + It.IsAny())) + .Returns(Task.FromResult(0)); + if (eventCallback != null) + { + sendEventFlow.Callback(eventCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext.Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs new file mode 100644 index 00000000..c0fed697 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/DisposeTests.cs @@ -0,0 +1,93 @@ +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class DisposeTests + { + [Fact] + public void DisposeExecutedQuery() + { + // If: + // ... I request a query (doesn't matter what kind) + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams {QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri}; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And then I dispose of the query + var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; + QueryDisposeResult result = null; + var disposeRequest = GetQueryDisposeResultContextMock(qdr => result = qdr, null); + queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object).Wait(); + + // Then: + // ... I should have seen a successful result + // ... And the active queries should be empty + VerifyQueryDisposeCallCount(disposeRequest, Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void QueryDisposeMissingQuery() + { + // If: + // ... I attempt to dispose a query that doesn't exist + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var disposeParams = new QueryDisposeParams {OwnerUri = Common.OwnerUri}; + QueryDisposeResult result = null; + var disposeRequest = GetQueryDisposeResultContextMock(qdr => result = qdr, null); + queryService.HandleDisposeRequest(disposeParams, disposeRequest.Object).Wait(); + + // Then: + // ... I should have gotten an error result + VerifyQueryDisposeCallCount(disposeRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + } + + #region Mocking + + private Mock> GetQueryDisposeResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private void VerifyQueryDisposeCallCount(Mock> mock, Times sendResultCalls, + Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs new file mode 100644 index 00000000..cddf1831 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/ExecuteTests.cs @@ -0,0 +1,386 @@ +using System; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class ExecuteTests + { + #region Query Class Tests + + [Fact] + public void QueryCreationTest() + { + // If I create a new query... + Query query = new Query("NO OP", Common.CreateTestConnectionInfo(null, false)); + + // Then: + // ... It should not have executed + Assert.False(query.HasExecuted, "The query should not have executed."); + + // ... The results should be empty + Assert.Empty(query.ResultSets); + Assert.Empty(query.ResultSummary); + } + + [Fact] + public void QueryExecuteNoResultSets() + { + // If I execute a query that should get no result sets + Query query = new Query("Query with no result sets", Common.CreateTestConnectionInfo(null, false)); + query.Execute().Wait(); + + // Then: + // ... It should have executed + Assert.True(query.HasExecuted, "The query should have been marked executed."); + + // ... The results should be empty + Assert.Empty(query.ResultSets); + Assert.Empty(query.ResultSummary); + + // ... The results should not be null + Assert.NotNull(query.ResultSets); + Assert.NotNull(query.ResultSummary); + } + + [Fact] + public void QueryExecuteQueryOneResultSet() + { + ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] {Common.StandardTestData}, false); + + // If I execute a query that should get one result set + int resultSets = 1; + int rows = 5; + int columns = 4; + Query query = new Query("Query with one result sets", ci); + query.Execute().Wait(); + + // Then: + // ... It should have executed + Assert.True(query.HasExecuted, "The query should have been marked executed."); + + // ... There should be exactly one result set + Assert.Equal(resultSets, query.ResultSets.Count); + + // ... Inside the result set should be with 5 rows + Assert.Equal(rows, query.ResultSets[0].Rows.Count); + + // ... Inside the result set should have 5 columns and 5 column definitions + Assert.Equal(columns, query.ResultSets[0].Rows[0].Length); + Assert.Equal(columns, query.ResultSets[0].Columns.Length); + + // ... There should be exactly one result set summary + Assert.Equal(resultSets, query.ResultSummary.Length); + + // ... Inside the result summary, there should be 5 column definitions + Assert.Equal(columns, query.ResultSummary[0].ColumnInfo.Length); + + // ... Inside the result summary, there should be 5 rows + Assert.Equal(rows, query.ResultSummary[0].RowCount); + } + + [Fact] + public void QueryExecuteQueryTwoResultSets() + { + var dataset = new[] {Common.StandardTestData, Common.StandardTestData}; + int resultSets = dataset.Length; + int rows = Common.StandardTestData.Length; + int columns = Common.StandardTestData[0].Count; + ConnectionInfo ci = Common.CreateTestConnectionInfo(dataset, false); + + // If I execute a query that should get two result sets + Query query = new Query("Query with two result sets", ci); + query.Execute().Wait(); + + // Then: + // ... It should have executed + Assert.True(query.HasExecuted, "The query should have been marked executed."); + + // ... There should be exactly two result sets + Assert.Equal(resultSets, query.ResultSets.Count); + + foreach (ResultSet rs in query.ResultSets) + { + // ... Each result set should have 5 rows + Assert.Equal(rows, rs.Rows.Count); + + // ... Inside each result set should be 5 columns and 5 column definitions + Assert.Equal(columns, rs.Rows[0].Length); + Assert.Equal(columns, rs.Columns.Length); + } + + // ... There should be exactly two result set summaries + Assert.Equal(resultSets, query.ResultSummary.Length); + + foreach (ResultSetSummary rs in query.ResultSummary) + { + // ... Inside each result summary, there should be 5 column definitions + Assert.Equal(columns, rs.ColumnInfo.Length); + + // ... Inside each result summary, there should be 5 rows + Assert.Equal(rows, rs.RowCount); + } + } + + [Fact] + public void QueryExecuteInvalidQuery() + { + ConnectionInfo ci = Common.CreateTestConnectionInfo(null, true); + + // If I execute a query that is invalid + Query query = new Query("Invalid query", ci); + + // Then: + // ... It should throw an exception + Exception e = Assert.Throws(() => query.Execute().Wait()); + } + + [Fact] + public void QueryExecuteExecutedQuery() + { + ConnectionInfo ci = Common.CreateTestConnectionInfo(new[] {Common.StandardTestData}, false); + + // If I execute a query + Query query = new Query("Any query", ci); + query.Execute().Wait(); + + // Then: + // ... It should have executed + Assert.True(query.HasExecuted, "The query should have been marked executed."); + + // If I execute it again + // Then: + // ... It should throw an invalid operation exception wrapped in an aggregate exception + AggregateException ae = Assert.Throws(() => query.Execute().Wait()); + Assert.Equal(1, ae.InnerExceptions.Count); + Assert.IsType(ae.InnerExceptions[0]); + + // ... The data should still be available + Assert.True(query.HasExecuted, "The query should still be marked executed."); + Assert.NotEmpty(query.ResultSets); + Assert.NotEmpty(query.ResultSummary); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + [InlineData(null)] + public void QueryExecuteNoQuery(string query) + { + // If: + // ... I create a query that has an empty query + // Then: + // ... It should throw an exception + Assert.Throws(() => new Query(query, null)); + } + + [Fact] + public void QueryExecuteNoConnectionInfo() + { + // If: + // ... I create a query that has a null connection info + // Then: + // ... It should throw an exception + Assert.Throws(() => new Query("Some Query", null)); + } + + #endregion + + #region Service Tests + + [Fact] + public void QueryExecuteValidNoResultsTest() + { + // If: + // ... I request to execute a valid query with no results + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { QueryText = "Doesn't Matter", OwnerUri = Common.OwnerUri }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, cp) => completeParams = cp, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No Errors should have been sent + // ... A successful result should have been sent with no messages + // ... A completion event should have been fired with empty results + // ... There should be one active query + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.Empty(completeParams.Messages); + Assert.Empty(completeParams.ResultSetSummaries); + Assert.False(completeParams.HasError); + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public void QueryExecuteValidResultsTest() + { + // If: + // ... I request to execute a valid query with results + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = "Doesn't Matter" }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams completeParams = null; + var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, cp) => completeParams = cp, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A successful result should have been sent with no messages + // ... A completion event should have been fired with one result + // ... There should be one active query + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.Empty(completeParams.Messages); + Assert.NotEmpty(completeParams.ResultSetSummaries); + Assert.False(completeParams.HasError); + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public void QueryExecuteUnconnectedUriTest() + { + // If: + // ... I request to execute a query using a file URI that isn't connected + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), false); + var queryParams = new QueryExecuteParams { OwnerUri = "notConnected", QueryText = "Doesn't Matter" }; + + QueryExecuteResult result = null; + var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, null, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... An error message should have been returned via the result + // ... No completion event should have been fired + // ... No error event should have been fired + // ... There should be no active queries + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + Assert.Empty(queryService.ActiveQueries); + } + + [Fact] + public void QueryExecuteInProgressTest() + { + // If: + // ... I request to execute a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = "Some Query" }; + + // Note, we don't care about the results of the first request + var firstRequestContext = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + + // ... And then I request another query without waiting for the first to complete + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; // Simulate query hasn't finished + QueryExecuteResult result = null; + var secondRequestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, null, null); + queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with an error message + // ... No completion event should have been fired + // ... There should only be one active query + VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.AtMostOnce(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Fact] + public void QueryExecuteCompletedTest() + { + // If: + // ... I request to execute a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = "Some Query" }; + + // Note, we don't care about the results of the first request + var firstRequestContext = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(queryParams, firstRequestContext.Object).Wait(); + + // ... And then I request another query after waiting for the first to complete + QueryExecuteResult result = null; + QueryExecuteCompleteParams complete = null; + var secondRequestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, qecp) => complete = qecp, null); + queryService.HandleExecuteRequest(queryParams, secondRequestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with no errors + // ... There should only be one active query + VerifyQueryExecuteCallCount(secondRequestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.False(complete.HasError); + Assert.Equal(1, queryService.ActiveQueries.Count); + } + + [Theory] + [InlineData("")] + [InlineData(" ")] + [InlineData(null)] + public void QueryExecuteMissingQueryTest(string query) + { + // If: + // ... I request to execute a query with a missing query string + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = query }; + + QueryExecuteResult result = null; + var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, null, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with an error message + // ... No completion event should have been fired + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Never(), Times.Never()); + Assert.NotNull(result.Messages); + Assert.NotEmpty(result.Messages); + } + + [Fact] + public void QueryExecuteInvalidQueryTest() + { + // If: + // ... I request to execute a query that is invalid + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, true), true); + var queryParams = new QueryExecuteParams { OwnerUri = Common.OwnerUri, QueryText = "Bad query!" }; + + QueryExecuteResult result = null; + QueryExecuteCompleteParams complete = null; + var requestContext = Common.GetQueryExecuteResultContextMock(qer => result = qer, (et, qecp) => complete = qecp, null); + queryService.HandleExecuteRequest(queryParams, requestContext.Object).Wait(); + + // Then: + // ... No errors should have been sent + // ... A result should have been sent with success (we successfully started the query) + // ... A completion event should have been sent with error + VerifyQueryExecuteCallCount(requestContext, Times.Once(), Times.Once(), Times.Never()); + Assert.Null(result.Messages); + Assert.True(complete.HasError); + Assert.NotEmpty(complete.Messages); + } + + #endregion + + private void VerifyQueryExecuteCallCount(Mock> mock, Times sendResultCalls, Times sendEventCalls, Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendEvent( + It.Is>(m => m == QueryExecuteCompleteEvent.Type), + It.IsAny()), sendEventCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs new file mode 100644 index 00000000..bdb0dc48 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/QueryExecution/SubsetTests.cs @@ -0,0 +1,206 @@ +using System; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.QueryExecution; +using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; +using Moq; +using Xunit; + +namespace Microsoft.SqlTools.ServiceLayer.Test.QueryExecution +{ + public class SubsetTests + { + #region Query Class Tests + + [Theory] + [InlineData(2)] + [InlineData(20)] + public void SubsetValidTest(int rowCount) + { + // If I have an executed query + Query q = Common.GetBasicExecutedQuery(); + + // ... And I ask for a subset with valid arguments + ResultSetSubset subset = q.GetSubset(0, 0, rowCount); + + // Then: + // I should get the requested number of rows + Assert.Equal(Math.Min(rowCount, Common.StandardTestData.Length), subset.RowCount); + Assert.Equal(Math.Min(rowCount, Common.StandardTestData.Length), subset.Rows.Length); + } + + [Fact] + public void SubsetUnexecutedQueryTest() + { + // If I have a query that has *not* been executed + Query q = new Query("NO OP", Common.CreateTestConnectionInfo(null, false)); + + // ... And I ask for a subset with valid arguments + // Then: + // ... It should throw an exception + Assert.Throws(() => q.GetSubset(0, 0, 2)); + } + + [Theory] + [InlineData(-1, 0, 2)] // Invalid result set, too low + [InlineData(2, 0, 2)] // Invalid result set, too high + [InlineData(0, -1, 2)] // Invalid start index, too low + [InlineData(0, 10, 2)] // Invalid start index, too high + [InlineData(0, 0, -1)] // Invalid row count, too low + [InlineData(0, 0, 0)] // Invalid row count, zero + public void SubsetInvalidParamsTest(int resultSetIndex, int rowStartInex, int rowCount) + { + // If I have an executed query + Query q = Common.GetBasicExecutedQuery(); + + // ... And I ask for a subset with an invalid result set index + // Then: + // ... It should throw an exception + Assert.Throws(() => q.GetSubset(resultSetIndex, rowStartInex, rowCount)); + } + + #endregion + + #region Service Intergration Tests + + [Fact] + public void SubsetServiceValidTest() + { + // If: + // ... I have a query that has results (doesn't matter what) + var queryService =Common.GetPrimedExecutionService( + Common.CreateMockFactory(new[] {Common.StandardTestData}, false), true); + var executeParams = new QueryExecuteParams {QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri}; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And I then ask for a valid set of results from it + var subsetParams = new QueryExecuteSubsetParams {OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0}; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should have a successful result + // ... There should be rows there (other test validate that the rows are correct) + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.Null(result.Message); + Assert.NotNull(result.ResultSubset); + } + + [Fact] + public void SubsetServiceMissingQueryTest() + { + // If: + // ... I ask for a set of results for a file that hasn't executed a query + var queryService = Common.GetPrimedExecutionService(Common.CreateMockFactory(null, false), true); + var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should have an error result + // ... There should be no rows in the result set + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Message); + Assert.Null(result.ResultSubset); + } + + [Fact] + public void SubsetServiceUnexecutedQueryTest() + { + // If: + // ... I have a query that hasn't finished executing (doesn't matter what) + var queryService = Common.GetPrimedExecutionService( + Common.CreateMockFactory(new[] { Common.StandardTestData }, false), true); + var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri }; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + queryService.ActiveQueries[Common.OwnerUri].HasExecuted = false; + + // ... And I then ask for a valid set of results from it + var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should get an error result + // ... There should not be rows + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Message); + Assert.Null(result.ResultSubset); + } + + [Fact] + public void SubsetServiceOutOfRangeSubsetTest() + { + // If: + // ... I have a query that doesn't have any result sets + var queryService = Common.GetPrimedExecutionService( + Common.CreateMockFactory(null, false), true); + var executeParams = new QueryExecuteParams { QueryText = "Doesn'tMatter", OwnerUri = Common.OwnerUri }; + var executeRequest = Common.GetQueryExecuteResultContextMock(null, null, null); + queryService.HandleExecuteRequest(executeParams, executeRequest.Object).Wait(); + + // ... And I then ask for a set of results from it + var subsetParams = new QueryExecuteSubsetParams { OwnerUri = Common.OwnerUri, RowsCount = 1, ResultSetIndex = 0, RowsStartIndex = 0 }; + QueryExecuteSubsetResult result = null; + var subsetRequest = GetQuerySubsetResultContextMock(qesr => result = qesr, null); + queryService.HandleResultSubsetRequest(subsetParams, subsetRequest.Object).Wait(); + + // Then: + // ... I should get an error result + // ... There should not be rows + // ... There should not be any error calls + VerifyQuerySubsetCallCount(subsetRequest, Times.Once(), Times.Never()); + Assert.NotNull(result.Message); + Assert.Null(result.ResultSubset); + } + + #endregion + + #region Mocking + + private Mock> GetQuerySubsetResultContextMock( + Action resultCallback, + Action errorCallback) + { + var requestContext = new Mock>(); + + // Setup the mock for SendResult + var sendResultFlow = requestContext + .Setup(rc => rc.SendResult(It.IsAny())) + .Returns(Task.FromResult(0)); + if (resultCallback != null) + { + sendResultFlow.Callback(resultCallback); + } + + // Setup the mock for SendError + var sendErrorFlow = requestContext + .Setup(rc => rc.SendError(It.IsAny())) + .Returns(Task.FromResult(0)); + if (errorCallback != null) + { + sendErrorFlow.Callback(errorCallback); + } + + return requestContext; + } + + private void VerifyQuerySubsetCallCount(Mock> mock, Times sendResultCalls, + Times sendErrorCalls) + { + mock.Verify(rc => rc.SendResult(It.IsAny()), sendResultCalls); + mock.Verify(rc => rc.SendError(It.IsAny()), sendErrorCalls); + } + + #endregion + + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs new file mode 100644 index 00000000..69edef72 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestDbDataReader.cs @@ -0,0 +1,207 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Data.Common; +using System.Linq; +using Moq; + +namespace Microsoft.SqlTools.ServiceLayer.Test.Utility +{ + public class TestDbDataReader : DbDataReader, IDbColumnSchemaGenerator + { + + #region Test Specific Implementations + + private Dictionary[][] Data { get; set; } + + public IEnumerator[]> ResultSet { get; private set; } + + private IEnumerator> Rows { get; set; } + + public TestDbDataReader(Dictionary[][] data) + { + Data = data; + if (Data != null) + { + ResultSet = ((IEnumerable[]>) Data).GetEnumerator(); + ResultSet.MoveNext(); + } + } + + #endregion + + public override bool HasRows + { + get { return ResultSet != null && ResultSet.Current.Length > 0; } + } + + public override bool Read() + { + if (Rows == null) + { + Rows = ((IEnumerable>) ResultSet.Current).GetEnumerator(); + } + return Rows.MoveNext(); + } + + public override bool NextResult() + { + if (Data == null || !ResultSet.MoveNext()) + { + return false; + } + Rows = ((IEnumerable>)ResultSet.Current).GetEnumerator(); + return true; + } + + public override object GetValue(int ordinal) + { + return this[ordinal]; + } + + public override object this[string name] + { + get { return Rows.Current[name]; } + } + + public override object this[int ordinal] + { + get { return Rows.Current[Rows.Current.Keys.AsEnumerable().ToArray()[ordinal]]; } + } + + public ReadOnlyCollection GetColumnSchema() + { + if (ResultSet?.Current == null || ResultSet.Current.Length <= 0) + { + return new ReadOnlyCollection(new List()); + } + + List columns = new List(); + for (int i = 0; i < ResultSet.Current[0].Count; i++) + { + columns.Add(new Mock().Object); + } + return new ReadOnlyCollection(columns); + } + + public override int FieldCount { get { return Rows?.Current.Count ?? 0; } } + + #region Not Implemented + + public override bool GetBoolean(int ordinal) + { + throw new NotImplementedException(); + } + + public override byte GetByte(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + { + throw new NotImplementedException(); + } + + public override char GetChar(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + { + throw new NotImplementedException(); + } + + public override string GetDataTypeName(int ordinal) + { + throw new NotImplementedException(); + } + + public override DateTime GetDateTime(int ordinal) + { + throw new NotImplementedException(); + } + + public override decimal GetDecimal(int ordinal) + { + throw new NotImplementedException(); + } + + public override double GetDouble(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetOrdinal(string name) + { + throw new NotImplementedException(); + } + + public override string GetName(int ordinal) + { + throw new NotImplementedException(); + } + + public override long GetInt64(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetInt32(int ordinal) + { + throw new NotImplementedException(); + } + + public override short GetInt16(int ordinal) + { + throw new NotImplementedException(); + } + + public override Guid GetGuid(int ordinal) + { + throw new NotImplementedException(); + } + + public override float GetFloat(int ordinal) + { + throw new NotImplementedException(); + } + + public override Type GetFieldType(int ordinal) + { + throw new NotImplementedException(); + } + + public override string GetString(int ordinal) + { + throw new NotImplementedException(); + } + + public override int GetValues(object[] values) + { + throw new NotImplementedException(); + } + + public override bool IsDBNull(int ordinal) + { + throw new NotImplementedException(); + } + + public override IEnumerator GetEnumerator() + { + throw new NotImplementedException(); + } + + public override int Depth { get; } + public override bool IsClosed { get; } + public override int RecordsAffected { get; } + + #endregion + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs index b973bfd9..b1ee31bb 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/Utility/TestObjects.cs @@ -18,6 +18,7 @@ using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; using Microsoft.SqlTools.ServiceLayer.LanguageServices; using Microsoft.SqlTools.ServiceLayer.SqlContext; +using Microsoft.SqlTools.ServiceLayer.Test.Utility; using Xunit; namespace Microsoft.SqlTools.Test.Utility @@ -97,179 +98,18 @@ namespace Microsoft.SqlTools.Test.Utility } } - public class TestDataReader : DbDataReader - { - - #region Test Specific Implementations - - internal string SqlCommandText { get; set; } - - private const string tableNameTestCommand = "SELECT name FROM sys.tables"; - - private List> tableNamesTest = new List> - { - new Dictionary { {"name", "table1"} }, - new Dictionary { {"name", "table2"} } - }; - - private IEnumerator> tableEnumerator; - - #endregion - - public override bool GetBoolean(int ordinal) - { - throw new NotImplementedException(); - } - - public override byte GetByte(int ordinal) - { - throw new NotImplementedException(); - } - - public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) - { - throw new NotImplementedException(); - } - - public override char GetChar(int ordinal) - { - throw new NotImplementedException(); - } - - public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) - { - throw new NotImplementedException(); - } - - public override string GetDataTypeName(int ordinal) - { - throw new NotImplementedException(); - } - - public override DateTime GetDateTime(int ordinal) - { - throw new NotImplementedException(); - } - - public override decimal GetDecimal(int ordinal) - { - throw new NotImplementedException(); - } - - public override double GetDouble(int ordinal) - { - throw new NotImplementedException(); - } - - public override IEnumerator GetEnumerator() - { - throw new NotImplementedException(); - } - - public override int GetOrdinal(string name) - { - throw new NotImplementedException(); - } - - public override string GetName(int ordinal) - { - throw new NotImplementedException(); - } - - public override long GetInt64(int ordinal) - { - throw new NotImplementedException(); - } - - public override int GetInt32(int ordinal) - { - throw new NotImplementedException(); - } - - public override short GetInt16(int ordinal) - { - throw new NotImplementedException(); - } - - public override Guid GetGuid(int ordinal) - { - throw new NotImplementedException(); - } - - public override float GetFloat(int ordinal) - { - throw new NotImplementedException(); - } - - public override Type GetFieldType(int ordinal) - { - throw new NotImplementedException(); - } - - public override string GetString(int ordinal) - { - throw new NotImplementedException(); - } - - public override object GetValue(int ordinal) - { - throw new NotImplementedException(); - } - - public override int GetValues(object[] values) - { - throw new NotImplementedException(); - } - - public override bool IsDBNull(int ordinal) - { - throw new NotImplementedException(); - } - - public override bool NextResult() - { - throw new NotImplementedException(); - } - - public override bool Read() - { - if (tableEnumerator == null) - { - switch (SqlCommandText) - { - case tableNameTestCommand: - tableEnumerator = ((IEnumerable>)tableNamesTest).GetEnumerator(); - break; - default: - throw new NotImplementedException(); - } - } - return tableEnumerator.MoveNext(); - } - - public override int Depth { get; } - public override bool IsClosed { get; } - public override int RecordsAffected { get; } - - public override object this[string name] - { - get { return tableEnumerator.Current[name]; } - } - - public override object this[int ordinal] - { - get { return tableEnumerator.Current[tableEnumerator.Current.Keys.ToArray()[ordinal]]; } - } - - public override int FieldCount { get; } - public override bool HasRows { get; } - } - /// /// Test mock class for IDbCommand /// public class TestSqlCommand : DbCommand { + internal TestSqlCommand(Dictionary[][] data) + { + Data = data; + } + + internal Dictionary[][] Data { get; set; } + public override void Cancel() { throw new NotImplementedException(); @@ -306,7 +146,7 @@ namespace Microsoft.SqlTools.Test.Utility protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) { - return new TestDataReader {SqlCommandText = CommandText}; + return new TestDbDataReader(Data); } } @@ -315,6 +155,13 @@ namespace Microsoft.SqlTools.Test.Utility /// public class TestSqlConnection : DbConnection { + internal TestSqlConnection(Dictionary[][] data) + { + Data = data; + } + + internal Dictionary[][] Data { get; set; } + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) { throw new NotImplementedException(); @@ -342,7 +189,7 @@ namespace Microsoft.SqlTools.Test.Utility protected override DbCommand CreateDbCommand() { - return new TestSqlCommand(); + return new TestSqlCommand(Data); } public override void ChangeDatabase(string databaseName) @@ -358,7 +205,7 @@ namespace Microsoft.SqlTools.Test.Utility { public DbConnection CreateSqlConnection(string connectionString) { - return new TestSqlConnection() + return new TestSqlConnection(null) { ConnectionString = connectionString }; diff --git a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json index 3d023cd4..23c97d0b 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.Test/project.json +++ b/test/Microsoft.SqlTools.ServiceLayer.Test/project.json @@ -14,7 +14,8 @@ "moq.netcore": "4.4.0-beta8", "Microsoft.SqlTools.ServiceLayer": { "target": "project" - } + }, + "System.Diagnostics.TraceSource": "4.0.0" }, "testRunner": "xunit", "frameworks": {