Adding ability to do stateless OE expansions from a parent node other than root (#2211)

This commit is contained in:
Aasim Khan
2023-09-06 22:25:02 +00:00
committed by GitHub
parent a56cf5a277
commit 4dbd1b05a4
5 changed files with 155 additions and 69 deletions

View File

@@ -33,7 +33,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
/// <exception cref="ArgumentNullException"> Thrown when the parent node is not found </exception>
/// <exception cref="TimeoutException"> Thrown when the operation times out.</exception> <summary>
/// </summary>
public static TreeNode[] Expand(string connectionString, SecurityToken? accessToken, string nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null)
public static async Task<TreeNode[]> Expand(string connectionString, SecurityToken? accessToken, string nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null)
{
using (SqlConnection conn = new SqlConnection(connectionString))
{
@@ -47,72 +47,99 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
connection = new ServerConnection(conn);
}
ServerNode serverNode = new ServerNode(serverInfo, connection, null, options.GroupBySchemaFlagGetter);
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
if(nodePath == null || nodePath == string.Empty)
try
{
nodePath = rootNode.GetNodePath();
return await Expand(connection, accessToken, nodePath, serverInfo, options, filters);
}
using (var taskCancellationTokenSource = new CancellationTokenSource())
finally
{
TreeNode? node = rootNode;
if (node == null)
if (connection.IsOpen)
{
// Return empty array if node is not found
return new TreeNode[0];
}
if (Monitor.TryEnter(node.BuildingMetadataLock, options.OperationTimeoutSeconds))
{
try
{
var token = accessToken == null ? null : accessToken.Token;
var task = Task.Run(() =>
{
var node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
if (node != null)
{
return node.Expand(taskCancellationTokenSource.Token, token, filters);
} else
{
throw new InvalidArgumentException($"Parent node not found for path {nodePath}");
}
});
if (task.Wait(TimeSpan.FromSeconds(options.OperationTimeoutSeconds)))
{
if (taskCancellationTokenSource.IsCancellationRequested)
{
throw new TimeoutException("The operation has timed out.");
}
return task.Result.ToArray();
}
else
{
throw new TimeoutException("The operation has timed out.");
}
}
finally
{
if (connection.IsOpen)
{
connection.Disconnect();
}
Monitor.Exit(node.BuildingMetadataLock);
}
}
else
{
throw new TimeoutException("The operation has timed out. Could not acquire the lock to build metadata for the node.");
connection.Disconnect();
}
}
}
}
/// <summary>
/// Expands the node at the given path and returns the child nodes. If parent is not null, it will skip expanding from the top and use the connection from the parent node.
/// </summary>
/// <param name="serverConnection"> Server connection to use for expanding the node. It will be used only if parent is null </param>
/// <param name="accessToken"> Access token to connect to the server. To be used in case of AAD based connections </param>
/// <param name="nodePath"> Path of the node to expand. Will be used only if parent is null </param>
/// <param name="serverInfo"> Server information </param>
/// <param name="options"> Object explorer expansion options </param>
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
/// <param name="parent"> Optional parent node. If provided, it will skip expanding from the top and and use the connection from the parent node </param>
/// <returns></returns>
/// </summary>
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null, TreeNode parent = null)
{
using (var taskCancellationTokenSource = new CancellationTokenSource())
{
try
{
var token = accessToken == null ? null : accessToken.Token;
var task = Task.Run(() =>
{
TreeNode? node;
if (parent == null)
{
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter);
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
if (nodePath == null || nodePath == string.Empty)
{
nodePath = rootNode.GetNodePath();
}
node = rootNode;
if (node == null)
{
// Return empty array if node is not found
return new TreeNode[0];
}
node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
}
else
{
node = parent;
}
if (node != null)
{
return node.Expand(taskCancellationTokenSource.Token, token, filters);
}
else
{
throw new InvalidArgumentException($"Parent node not found for path {nodePath}");
}
});
if (await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) == task)
{
if (taskCancellationTokenSource.IsCancellationRequested)
{
throw new TimeoutException("The operation has timed out.");
}
return task.Result.ToArray();
}
else
{
throw new TimeoutException("The operation has timed out.");
}
}
catch (Exception ex)
{
throw ex;
}
}
}
}
}

View File

@@ -6,17 +6,46 @@
using System;
using System.Threading.Tasks;
using Azure.Core;
using Microsoft.SqlServer.Management.Common;
using Microsoft.SqlTools.SqlCore.Scripting.Contracts;
namespace Microsoft.SqlTools.SqlCore.Scripting
{
public class AsyncScriptAsScriptingOperation
{
public static async Task<string> GetScriptAsScript(ScriptingParams parameters, AccessToken? accessToken)
public static async Task<string> GetScriptAsScript(ScriptingParams parameters)
{
var scriptAsOperation = new ScriptAsScriptingOperation(parameters, string.Empty);
return await ExecuteScriptAs(scriptAsOperation);
}
/// <summary>
/// Gets the script as script like select, insert, update, drop and create for the given scripting parameters.
/// </summary>
/// <param name="parameters">scripting parameters that contains the object to script and the scripting options</param>
/// <param name="accessToken">access token to connect to the server. To be used in case of AAD based connections</param>
/// <returns>script as script</returns>
public static async Task<string> GetScriptAsScript(ScriptingParams parameters, ServerConnection? serverConnection, AccessToken? accessToken)
{
var scriptAsOperation = new ScriptAsScriptingOperation(parameters, accessToken?.Token);
TaskCompletionSource<string> scriptAsTask = new TaskCompletionSource<string>();
return await ExecuteScriptAs(scriptAsOperation);
}
/// <summary>
/// Gets the script as script like select, insert, update, drop and create for the given scripting parameters.
/// </summary>
/// <param name="parameters">scripting parameters that contains the object to script and the scripting options</param>
/// <param name="serverConnection">server connection to use for scripting</param>
/// <returns>script as script</returns>
public static async Task<string> GetScriptAsScript(ScriptingParams parameters, ServerConnection? serverConnection)
{
var scriptAsOperation = new ScriptAsScriptingOperation(parameters, serverConnection);
return await ExecuteScriptAs(scriptAsOperation);
}
private static async Task<string> ExecuteScriptAs(ScriptAsScriptingOperation scriptAsOperation)
{
TaskCompletionSource<string> scriptAsTask = new TaskCompletionSource<string>();
scriptAsOperation.CompleteNotification += (sender, args) =>
{
if (args.HasError)

View File

@@ -47,13 +47,13 @@ namespace Microsoft.SqlTools.SqlCore.Scripting
{
SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString);
sqlConnection.RetryLogicProvider = SqlRetryProviders.ServerlessDBRetryProvider();
if (azureAccountToken != null)
if (!string.IsNullOrEmpty(azureAccountToken))
{
sqlConnection.AccessToken = azureAccountToken;
}
ServerConnection = new ServerConnection(sqlConnection);
if (azureAccountToken != null)
if (!string.IsNullOrEmpty(azureAccountToken))
{
ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken);
}

View File

@@ -4,12 +4,12 @@
//
using System;
using System.Threading.Tasks;
using NUnit.Framework;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using Microsoft.SqlTools.SqlCore.ObjectExplorer;
using Microsoft.SqlTools.ServiceLayer.Test.Common.Extensions;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.SqlTools.ServiceLayer.Test.Common;
using Microsoft.SqlTools.ServiceLayer.Test.Common.Extensions;
using Microsoft.SqlTools.SqlCore.ObjectExplorer;
using NUnit.Framework;
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer
{
@@ -48,12 +48,42 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer
{
serverInfo.DatabaseName = testdbName;
var pathWithDb = string.Format(oePath, testdbName);
var nodes = StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
Assert.True(nodes.Any(node => node.Label == childLabel), $"Expansion result for {pathWithDb} does not contain node {childLabel}");
});
}
[Test]
public async Task ProvidingNodeShouldSkipExpandingFromTop()
{
var query = @"Create table t1 (c1 int)
GO
Create table t2 (c1 int)
GO";
await RunTest(databaseName, query, "testdb", async (testdbName, connectionString) =>
{
serverInfo.DatabaseName = testdbName;
var oePath = "";
var pathWithDb = string.Format(oePath, testdbName);
var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes[0]);
Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Tables"));
Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "dbo.t1"));
Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Columns"));
Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1");
});
}
private async Task RunTest(string databaseName, string query, string testDbPrefix, Func<string, string, Task> test)
{
SqlTestDb? testDb = null;

View File

@@ -105,7 +105,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Scripting
var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, query, "ScriptingTests");
scriptingParams.ConnectionString = testDb.ConnectionString;
var actualScript = await AsyncScriptAsScriptingOperation.GetScriptAsScript(scriptingParams, null);
var actualScript = await AsyncScriptAsScriptingOperation.GetScriptAsScript(scriptingParams);
foreach(var expectedStr in expectedScriptContents)
{