diff --git a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
index 088d6f16..c8f52119 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/QueryExecution/QueryExecutionService.cs
@@ -4,9 +4,12 @@
//
using System;
using System.Collections.Concurrent;
+using System.Data.Common;
+using System.Data.SqlClient;
using System.IO;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests;
@@ -118,6 +121,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
internal SqlToolsSettings Settings { get; set; }
+ ///
+ /// Holds a map from the simple execute unique GUID and the underlying task that is being ran
+ ///
+ private readonly Lazy> simpleExecuteRequests =
+ new Lazy>(() => new ConcurrentDictionary());
+
+ ///
+ /// Holds a map from the simple execute unique GUID and the underlying task that is being ran
+ ///
+ internal ConcurrentDictionary ActiveSimpleExecuteRequests => simpleExecuteRequests.Value;
+
#endregion
///
@@ -173,82 +187,111 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
///
/// Handles a request to execute a string and return the result
///
- internal Task HandleSimpleExecuteRequest(SimpleExecuteParams executeParams,
+ internal async Task HandleSimpleExecuteRequest(SimpleExecuteParams executeParams,
RequestContext requestContext)
{
- ExecuteStringParams executeStringParams = new ExecuteStringParams
+ try
{
- Query = executeParams.QueryString,
- // generate guid as the owner uri to make sure every query is unique
- OwnerUri = Guid.NewGuid().ToString()
- };
-
- // get connection
- ConnectionInfo connInfo;
- if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connInfo))
- {
- return requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri);
- }
-
- if (connInfo.ConnectionDetails.MultipleActiveResultSets == null || connInfo.ConnectionDetails.MultipleActiveResultSets == false) {
- // if multipleActive result sets is not allowed, don't specific a connection and make the ownerURI the true owneruri
- connInfo = null;
- executeStringParams.OwnerUri = executeParams.OwnerUri;
- }
-
- Func queryCreateFailureAction = message => requestContext.SendError(message);
-
- ResultOnlyContext newContext = new ResultOnlyContext(requestContext);
-
- // handle sending event back when the query completes
- Query.QueryAsyncEventHandler queryComplete = async q =>
- {
- Query removedQuery;
- // check to make sure any results were recieved
- if (q.Batches.Length == 0 || q.Batches[0].ResultSets.Count == 0)
+ string randomUri = Guid.NewGuid().ToString();
+ ExecuteStringParams executeStringParams = new ExecuteStringParams
{
- await requestContext.SendError(SR.QueryServiceResultSetHasNoResults);
- ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery);
- return;
- }
+ Query = executeParams.QueryString,
+ // generate guid as the owner uri to make sure every query is unique
+ OwnerUri = randomUri
+ };
- var rowCount = q.Batches[0].ResultSets[0].RowCount;
- // check to make sure there is a safe amount of rows to load into memory
- if (rowCount > Int32.MaxValue)
+ // get connection
+ ConnectionInfo connInfo;
+ if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connInfo))
{
- await requestContext.SendError(SR.QueryServiceResultSetTooLarge);
- ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery);
+ await requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri);
return;
}
- SubsetParams subsetRequestParams = new SubsetParams
+ ConnectParams connectParams = new ConnectParams
{
- OwnerUri = executeStringParams.OwnerUri,
- BatchIndex = 0,
- ResultSetIndex = 0,
- RowsStartIndex = 0,
- RowsCount = Convert.ToInt32(rowCount)
+ OwnerUri = randomUri,
+ Connection = connInfo.ConnectionDetails,
+ Type = ConnectionType.Default
};
- // get the data to send back
- ResultSetSubset subset = await InterServiceResultSubset(subsetRequestParams);
- SimpleExecuteResult result = new SimpleExecuteResult
- {
- RowCount = q.Batches[0].ResultSets[0].RowCount,
- ColumnInfo = q.Batches[0].ResultSets[0].Columns,
- Rows = subset.Rows
- };
- await requestContext.SendResult(result);
- // remove the active query since we are done with it
- ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery);
- };
+
+ Task workTask = Task.Run(async () => {
+ await ConnectionService.Connect(connectParams);
- // handle sending error back when query fails
- Query.QueryAsyncErrorEventHandler queryFail = async (q, e) =>
+ ConnectionInfo newConn;
+ ConnectionService.TryFindConnection(randomUri, out newConn);
+
+ Func queryCreateFailureAction = message => requestContext.SendError(message);
+
+ ResultOnlyContext newContext = new ResultOnlyContext(requestContext);
+
+ // handle sending event back when the query completes
+ Query.QueryAsyncEventHandler queryComplete = async query =>
+ {
+ try
+ {
+ // check to make sure any results were recieved
+ if (query.Batches.Length == 0 || query.Batches[0].ResultSets.Count == 0)
+ {
+ await requestContext.SendError(SR.QueryServiceResultSetHasNoResults);
+ return;
+ }
+
+ var rowCount = query.Batches[0].ResultSets[0].RowCount;
+ // check to make sure there is a safe amount of rows to load into memory
+ if (rowCount > Int32.MaxValue)
+ {
+ await requestContext.SendError(SR.QueryServiceResultSetTooLarge);
+ return;
+ }
+
+ SubsetParams subsetRequestParams = new SubsetParams
+ {
+ OwnerUri = randomUri,
+ BatchIndex = 0,
+ ResultSetIndex = 0,
+ RowsStartIndex = 0,
+ RowsCount = Convert.ToInt32(rowCount)
+ };
+ // get the data to send back
+ ResultSetSubset subset = await InterServiceResultSubset(subsetRequestParams);
+ SimpleExecuteResult result = new SimpleExecuteResult
+ {
+ RowCount = query.Batches[0].ResultSets[0].RowCount,
+ ColumnInfo = query.Batches[0].ResultSets[0].Columns,
+ Rows = subset.Rows
+ };
+ await requestContext.SendResult(result);
+ }
+ finally
+ {
+ Query removedQuery;
+ Task removedTask;
+ // remove the active query since we are done with it
+ ActiveQueries.TryRemove(randomUri, out removedQuery);
+ ActiveSimpleExecuteRequests.TryRemove(randomUri, out removedTask);
+ ConnectionService.Disconnect(new DisconnectParams(){
+ OwnerUri = randomUri,
+ Type = null
+ });
+ }
+ };
+
+ // handle sending error back when query fails
+ Query.QueryAsyncErrorEventHandler queryFail = async (q, e) =>
+ {
+ await requestContext.SendError(e);
+ };
+
+ await InterServiceExecuteQuery(executeStringParams, newConn, newContext, null, queryCreateFailureAction, queryComplete, queryFail);
+ });
+
+ ActiveSimpleExecuteRequests.TryAdd(randomUri, workTask);
+ }
+ catch(Exception ex)
{
- await requestContext.SendError(e);
- };
-
- return InterServiceExecuteQuery(executeStringParams, connInfo, newContext, null, queryCreateFailureAction, queryComplete, queryFail);
+ await requestContext.SendError(ex.ToString());
+ }
}
///
diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs
index 2a37a1f1..802373da 100644
--- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs
+++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/QueryExecution/Execution/ServiceIntegrationTests.cs
@@ -4,6 +4,7 @@
//
using System;
+using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests;
@@ -431,15 +432,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
.Complete();
await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object);
- Query q;
- queryService.ActiveQueries.TryGetValue(Constants.OwnerUri, out q);
+ await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values);
- // wait on the task to finish
+ Query q = queryService.ActiveQueries.Values.First();
+ Assert.NotNull(q);
q.ExecutionTask.Wait();
efv.Validate();
Assert.Equal(0, queryService.ActiveQueries.Count);
+
}
[Fact]
@@ -452,8 +454,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
.Complete();
await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object);
- Query q;
- queryService.ActiveQueries.TryGetValue(Constants.OwnerUri, out q);
+ await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values);
+
+ Query q = queryService.ActiveQueries.Values.First();
+
+ Assert.NotNull(q);
// wait on the task to finish
q.ExecutionTask.Wait();
@@ -463,6 +468,41 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
Assert.Equal(0, queryService.ActiveQueries.Count);
}
+ [Fact]
+ public async Task SimpleExecuteMultipleQueriesTest()
+ {
+ var queryService = Common.GetPrimedExecutionService(Common.StandardTestDataSet, true, false, null);
+ var queryParams = new SimpleExecuteParams { OwnerUri = Constants.OwnerUri, QueryString = Constants.StandardQuery };
+ var efv1 = new EventFlowValidator()
+ .AddSimpleExecuteQueryResultValidator(Common.StandardTestDataSet)
+ .Complete();
+ var efv2 = new EventFlowValidator()
+ .AddSimpleExecuteQueryResultValidator(Common.StandardTestDataSet)
+ .Complete();
+ Task qT1 = queryService.HandleSimpleExecuteRequest(queryParams, efv1.Object);
+ Task qT2 = queryService.HandleSimpleExecuteRequest(queryParams, efv2.Object);
+
+ await Task.WhenAll(qT1, qT2);
+
+ await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values);
+
+ var queries = queryService.ActiveQueries.Values.Take(2).ToArray();
+ Query q1 = queries[0];
+ Query q2 = queries[1];
+
+ Assert.NotNull(q1);
+ Assert.NotNull(q2);
+
+ // wait on the task to finish
+ q1.ExecutionTask.Wait();
+ q2.ExecutionTask.Wait();
+
+ efv1.Validate();
+ efv2.Validate();
+
+ Assert.Equal(0, queryService.ActiveQueries.Count);
+ }
+
#endregion
private static WorkspaceService GetDefaultWorkspaceService(string query)