Multiple Connection Simple Execute (#421)

* change simple execute to open a new connection and close it every query

* updated tests for simple execute

* removed an unnecessary connect

* refactored code to be more readable

* global try catch on simple execute

* added multiple execution test

* update execution to be asynchrous; update tests to account for asynchrounous nature
This commit is contained in:
Anthony Dresser
2017-07-28 13:35:46 -07:00
committed by GitHub
parent e453a19d00
commit 7ef81d0e54
2 changed files with 151 additions and 68 deletions

View File

@@ -4,9 +4,12 @@
// //
using System; using System;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Data.Common;
using System.Data.SqlClient;
using System.IO; using System.IO;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests;
@@ -118,6 +121,17 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// </summary> /// </summary>
internal SqlToolsSettings Settings { get; set; } internal SqlToolsSettings Settings { get; set; }
/// <summary>
/// Holds a map from the simple execute unique GUID and the underlying task that is being ran
/// </summary>
private readonly Lazy<ConcurrentDictionary<string, Task>> simpleExecuteRequests =
new Lazy<ConcurrentDictionary<string, Task>>(() => new ConcurrentDictionary<string, Task>());
/// <summary>
/// Holds a map from the simple execute unique GUID and the underlying task that is being ran
/// </summary>
internal ConcurrentDictionary<string, Task> ActiveSimpleExecuteRequests => simpleExecuteRequests.Value;
#endregion #endregion
/// <summary> /// <summary>
@@ -173,82 +187,111 @@ namespace Microsoft.SqlTools.ServiceLayer.QueryExecution
/// <summary> /// <summary>
/// Handles a request to execute a string and return the result /// Handles a request to execute a string and return the result
/// </summary> /// </summary>
internal Task HandleSimpleExecuteRequest(SimpleExecuteParams executeParams, internal async Task HandleSimpleExecuteRequest(SimpleExecuteParams executeParams,
RequestContext<SimpleExecuteResult> requestContext) RequestContext<SimpleExecuteResult> requestContext)
{ {
ExecuteStringParams executeStringParams = new ExecuteStringParams try
{ {
Query = executeParams.QueryString, string randomUri = Guid.NewGuid().ToString();
// generate guid as the owner uri to make sure every query is unique ExecuteStringParams executeStringParams = new ExecuteStringParams
OwnerUri = Guid.NewGuid().ToString()
};
// get connection
ConnectionInfo connInfo;
if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connInfo))
{
return requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri);
}
if (connInfo.ConnectionDetails.MultipleActiveResultSets == null || connInfo.ConnectionDetails.MultipleActiveResultSets == false) {
// if multipleActive result sets is not allowed, don't specific a connection and make the ownerURI the true owneruri
connInfo = null;
executeStringParams.OwnerUri = executeParams.OwnerUri;
}
Func<string, Task> queryCreateFailureAction = message => requestContext.SendError(message);
ResultOnlyContext<SimpleExecuteResult> newContext = new ResultOnlyContext<SimpleExecuteResult>(requestContext);
// handle sending event back when the query completes
Query.QueryAsyncEventHandler queryComplete = async q =>
{
Query removedQuery;
// check to make sure any results were recieved
if (q.Batches.Length == 0 || q.Batches[0].ResultSets.Count == 0)
{ {
await requestContext.SendError(SR.QueryServiceResultSetHasNoResults); Query = executeParams.QueryString,
ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery); // generate guid as the owner uri to make sure every query is unique
return; OwnerUri = randomUri
} };
var rowCount = q.Batches[0].ResultSets[0].RowCount; // get connection
// check to make sure there is a safe amount of rows to load into memory ConnectionInfo connInfo;
if (rowCount > Int32.MaxValue) if (!ConnectionService.TryFindConnection(executeParams.OwnerUri, out connInfo))
{ {
await requestContext.SendError(SR.QueryServiceResultSetTooLarge); await requestContext.SendError(SR.QueryServiceQueryInvalidOwnerUri);
ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery);
return; return;
} }
SubsetParams subsetRequestParams = new SubsetParams ConnectParams connectParams = new ConnectParams
{ {
OwnerUri = executeStringParams.OwnerUri, OwnerUri = randomUri,
BatchIndex = 0, Connection = connInfo.ConnectionDetails,
ResultSetIndex = 0, Type = ConnectionType.Default
RowsStartIndex = 0,
RowsCount = Convert.ToInt32(rowCount)
}; };
// get the data to send back
ResultSetSubset subset = await InterServiceResultSubset(subsetRequestParams); Task workTask = Task.Run(async () => {
SimpleExecuteResult result = new SimpleExecuteResult await ConnectionService.Connect(connectParams);
{
RowCount = q.Batches[0].ResultSets[0].RowCount,
ColumnInfo = q.Batches[0].ResultSets[0].Columns,
Rows = subset.Rows
};
await requestContext.SendResult(result);
// remove the active query since we are done with it
ActiveQueries.TryRemove(executeStringParams.OwnerUri, out removedQuery);
};
// handle sending error back when query fails ConnectionInfo newConn;
Query.QueryAsyncErrorEventHandler queryFail = async (q, e) => ConnectionService.TryFindConnection(randomUri, out newConn);
Func<string, Task> queryCreateFailureAction = message => requestContext.SendError(message);
ResultOnlyContext<SimpleExecuteResult> newContext = new ResultOnlyContext<SimpleExecuteResult>(requestContext);
// handle sending event back when the query completes
Query.QueryAsyncEventHandler queryComplete = async query =>
{
try
{
// check to make sure any results were recieved
if (query.Batches.Length == 0 || query.Batches[0].ResultSets.Count == 0)
{
await requestContext.SendError(SR.QueryServiceResultSetHasNoResults);
return;
}
var rowCount = query.Batches[0].ResultSets[0].RowCount;
// check to make sure there is a safe amount of rows to load into memory
if (rowCount > Int32.MaxValue)
{
await requestContext.SendError(SR.QueryServiceResultSetTooLarge);
return;
}
SubsetParams subsetRequestParams = new SubsetParams
{
OwnerUri = randomUri,
BatchIndex = 0,
ResultSetIndex = 0,
RowsStartIndex = 0,
RowsCount = Convert.ToInt32(rowCount)
};
// get the data to send back
ResultSetSubset subset = await InterServiceResultSubset(subsetRequestParams);
SimpleExecuteResult result = new SimpleExecuteResult
{
RowCount = query.Batches[0].ResultSets[0].RowCount,
ColumnInfo = query.Batches[0].ResultSets[0].Columns,
Rows = subset.Rows
};
await requestContext.SendResult(result);
}
finally
{
Query removedQuery;
Task removedTask;
// remove the active query since we are done with it
ActiveQueries.TryRemove(randomUri, out removedQuery);
ActiveSimpleExecuteRequests.TryRemove(randomUri, out removedTask);
ConnectionService.Disconnect(new DisconnectParams(){
OwnerUri = randomUri,
Type = null
});
}
};
// handle sending error back when query fails
Query.QueryAsyncErrorEventHandler queryFail = async (q, e) =>
{
await requestContext.SendError(e);
};
await InterServiceExecuteQuery(executeStringParams, newConn, newContext, null, queryCreateFailureAction, queryComplete, queryFail);
});
ActiveSimpleExecuteRequests.TryAdd(randomUri, workTask);
}
catch(Exception ex)
{ {
await requestContext.SendError(e); await requestContext.SendError(ex.ToString());
}; }
return InterServiceExecuteQuery(executeStringParams, connInfo, newContext, null, queryCreateFailureAction, queryComplete, queryFail);
} }
/// <summary> /// <summary>

View File

@@ -4,6 +4,7 @@
// //
using System; using System;
using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.QueryExecution; using Microsoft.SqlTools.ServiceLayer.QueryExecution;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests; using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts.ExecuteRequests;
@@ -431,15 +432,16 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
.Complete(); .Complete();
await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object); await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object);
Query q; await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values);
queryService.ActiveQueries.TryGetValue(Constants.OwnerUri, out q);
// wait on the task to finish Query q = queryService.ActiveQueries.Values.First();
Assert.NotNull(q);
q.ExecutionTask.Wait(); q.ExecutionTask.Wait();
efv.Validate(); efv.Validate();
Assert.Equal(0, queryService.ActiveQueries.Count); Assert.Equal(0, queryService.ActiveQueries.Count);
} }
[Fact] [Fact]
@@ -452,8 +454,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
.Complete(); .Complete();
await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object); await queryService.HandleSimpleExecuteRequest(queryParams, efv.Object);
Query q; await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values);
queryService.ActiveQueries.TryGetValue(Constants.OwnerUri, out q);
Query q = queryService.ActiveQueries.Values.First();
Assert.NotNull(q);
// wait on the task to finish // wait on the task to finish
q.ExecutionTask.Wait(); q.ExecutionTask.Wait();
@@ -463,6 +468,41 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.QueryExecution.Execution
Assert.Equal(0, queryService.ActiveQueries.Count); Assert.Equal(0, queryService.ActiveQueries.Count);
} }
[Fact]
public async Task SimpleExecuteMultipleQueriesTest()
{
var queryService = Common.GetPrimedExecutionService(Common.StandardTestDataSet, true, false, null);
var queryParams = new SimpleExecuteParams { OwnerUri = Constants.OwnerUri, QueryString = Constants.StandardQuery };
var efv1 = new EventFlowValidator<SimpleExecuteResult>()
.AddSimpleExecuteQueryResultValidator(Common.StandardTestDataSet)
.Complete();
var efv2 = new EventFlowValidator<SimpleExecuteResult>()
.AddSimpleExecuteQueryResultValidator(Common.StandardTestDataSet)
.Complete();
Task qT1 = queryService.HandleSimpleExecuteRequest(queryParams, efv1.Object);
Task qT2 = queryService.HandleSimpleExecuteRequest(queryParams, efv2.Object);
await Task.WhenAll(qT1, qT2);
await Task.WhenAll(queryService.ActiveSimpleExecuteRequests.Values);
var queries = queryService.ActiveQueries.Values.Take(2).ToArray();
Query q1 = queries[0];
Query q2 = queries[1];
Assert.NotNull(q1);
Assert.NotNull(q2);
// wait on the task to finish
q1.ExecutionTask.Wait();
q2.ExecutionTask.Wait();
efv1.Validate();
efv2.Validate();
Assert.Equal(0, queryService.ActiveQueries.Count);
}
#endregion #endregion
private static WorkspaceService<SqlToolsSettings> GetDefaultWorkspaceService(string query) private static WorkspaceService<SqlToolsSettings> GetDefaultWorkspaceService(string query)