diff --git a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs index 199a3ce7..d6907245 100644 --- a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs +++ b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs @@ -33,7 +33,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer /// Thrown when the parent node is not found /// Thrown when the operation times out. /// - public static TreeNode[] Expand(string connectionString, SecurityToken? accessToken, string nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null) + public static async Task Expand(string connectionString, SecurityToken? accessToken, string nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null) { using (SqlConnection conn = new SqlConnection(connectionString)) { @@ -47,72 +47,99 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer connection = new ServerConnection(conn); } - ServerNode serverNode = new ServerNode(serverInfo, connection, null, options.GroupBySchemaFlagGetter); - - TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName); - - if(nodePath == null || nodePath == string.Empty) + try { - nodePath = rootNode.GetNodePath(); + return await Expand(connection, accessToken, nodePath, serverInfo, options, filters); } - - using (var taskCancellationTokenSource = new CancellationTokenSource()) + finally { - TreeNode? node = rootNode; - if (node == null) + if (connection.IsOpen) { - // Return empty array if node is not found - return new TreeNode[0]; - } - - if (Monitor.TryEnter(node.BuildingMetadataLock, options.OperationTimeoutSeconds)) - { - try - { - var token = accessToken == null ? null : accessToken.Token; - - var task = Task.Run(() => - { - var node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token); - if (node != null) - { - return node.Expand(taskCancellationTokenSource.Token, token, filters); - } else - { - throw new InvalidArgumentException($"Parent node not found for path {nodePath}"); - } - }); - - if (task.Wait(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) - { - if (taskCancellationTokenSource.IsCancellationRequested) - { - throw new TimeoutException("The operation has timed out."); - } - return task.Result.ToArray(); - } - else - { - throw new TimeoutException("The operation has timed out."); - } - } - finally - { - if (connection.IsOpen) - { - connection.Disconnect(); - } - Monitor.Exit(node.BuildingMetadataLock); - } - } - else - { - throw new TimeoutException("The operation has timed out. Could not acquire the lock to build metadata for the node."); + connection.Disconnect(); } } } } + /// + /// Expands the node at the given path and returns the child nodes. If parent is not null, it will skip expanding from the top and use the connection from the parent node. + /// + /// Server connection to use for expanding the node. It will be used only if parent is null + /// Access token to connect to the server. To be used in case of AAD based connections + /// Path of the node to expand. Will be used only if parent is null + /// Server information + /// Object explorer expansion options + /// Filters to be applied on the leaf nodes + /// Optional parent node. If provided, it will skip expanding from the top and and use the connection from the parent node + /// + /// + + + public static async Task Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null, TreeNode parent = null) + { + using (var taskCancellationTokenSource = new CancellationTokenSource()) + { + + try + { + var token = accessToken == null ? null : accessToken.Token; + + var task = Task.Run(() => + { + TreeNode? node; + if (parent == null) + { + ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter); + TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName); + + if (nodePath == null || nodePath == string.Empty) + { + nodePath = rootNode.GetNodePath(); + } + node = rootNode; + if (node == null) + { + // Return empty array if node is not found + return new TreeNode[0]; + } + node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token); + } + else + { + node = parent; + } + + if (node != null) + { + return node.Expand(taskCancellationTokenSource.Token, token, filters); + } + else + { + throw new InvalidArgumentException($"Parent node not found for path {nodePath}"); + } + }); + + + if (await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) == task) + { + if (taskCancellationTokenSource.IsCancellationRequested) + { + throw new TimeoutException("The operation has timed out."); + } + return task.Result.ToArray(); + } + else + { + throw new TimeoutException("The operation has timed out."); + } + } + catch (Exception ex) + { + throw ex; + } + + } + } } } \ No newline at end of file diff --git a/src/Microsoft.SqlTools.SqlCore/Scripting/AsyncScriptAsScriptingOperation.cs b/src/Microsoft.SqlTools.SqlCore/Scripting/AsyncScriptAsScriptingOperation.cs index 01477bdd..cd398c8a 100644 --- a/src/Microsoft.SqlTools.SqlCore/Scripting/AsyncScriptAsScriptingOperation.cs +++ b/src/Microsoft.SqlTools.SqlCore/Scripting/AsyncScriptAsScriptingOperation.cs @@ -6,17 +6,46 @@ using System; using System.Threading.Tasks; using Azure.Core; +using Microsoft.SqlServer.Management.Common; using Microsoft.SqlTools.SqlCore.Scripting.Contracts; namespace Microsoft.SqlTools.SqlCore.Scripting { public class AsyncScriptAsScriptingOperation { - public static async Task GetScriptAsScript(ScriptingParams parameters, AccessToken? accessToken) + public static async Task GetScriptAsScript(ScriptingParams parameters) + { + var scriptAsOperation = new ScriptAsScriptingOperation(parameters, string.Empty); + return await ExecuteScriptAs(scriptAsOperation); + } + + /// + /// Gets the script as script like select, insert, update, drop and create for the given scripting parameters. + /// + /// scripting parameters that contains the object to script and the scripting options + /// access token to connect to the server. To be used in case of AAD based connections + /// script as script + public static async Task GetScriptAsScript(ScriptingParams parameters, ServerConnection? serverConnection, AccessToken? accessToken) { var scriptAsOperation = new ScriptAsScriptingOperation(parameters, accessToken?.Token); - TaskCompletionSource scriptAsTask = new TaskCompletionSource(); + return await ExecuteScriptAs(scriptAsOperation); + } + /// + /// Gets the script as script like select, insert, update, drop and create for the given scripting parameters. + /// + /// scripting parameters that contains the object to script and the scripting options + /// server connection to use for scripting + /// script as script + public static async Task GetScriptAsScript(ScriptingParams parameters, ServerConnection? serverConnection) + { + var scriptAsOperation = new ScriptAsScriptingOperation(parameters, serverConnection); + return await ExecuteScriptAs(scriptAsOperation); + } + + private static async Task ExecuteScriptAs(ScriptAsScriptingOperation scriptAsOperation) + { + TaskCompletionSource scriptAsTask = new TaskCompletionSource(); scriptAsOperation.CompleteNotification += (sender, args) => { if (args.HasError) diff --git a/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs b/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs index 8a9809fb..9101a37a 100644 --- a/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs +++ b/src/Microsoft.SqlTools.SqlCore/Scripting/ScriptAsScriptingOperation.cs @@ -47,13 +47,13 @@ namespace Microsoft.SqlTools.SqlCore.Scripting { SqlConnection sqlConnection = new SqlConnection(this.Parameters.ConnectionString); sqlConnection.RetryLogicProvider = SqlRetryProviders.ServerlessDBRetryProvider(); - if (azureAccountToken != null) + if (!string.IsNullOrEmpty(azureAccountToken)) { sqlConnection.AccessToken = azureAccountToken; } ServerConnection = new ServerConnection(sqlConnection); - if (azureAccountToken != null) + if (!string.IsNullOrEmpty(azureAccountToken)) { ServerConnection.AccessToken = new AzureAccessToken(azureAccountToken); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs index 3c42704b..85903372 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs @@ -4,12 +4,12 @@ // using System; -using System.Threading.Tasks; -using NUnit.Framework; -using Microsoft.SqlTools.ServiceLayer.Test.Common; -using Microsoft.SqlTools.SqlCore.ObjectExplorer; -using Microsoft.SqlTools.ServiceLayer.Test.Common.Extensions; using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using Microsoft.SqlTools.ServiceLayer.Test.Common.Extensions; +using Microsoft.SqlTools.SqlCore.ObjectExplorer; +using NUnit.Framework; namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer { @@ -48,12 +48,42 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer { serverInfo.DatabaseName = testdbName; var pathWithDb = string.Format(oePath, testdbName); - var nodes = StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options); + var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options); Assert.True(nodes.Any(node => node.Label == childLabel), $"Expansion result for {pathWithDb} does not contain node {childLabel}"); }); } + [Test] + public async Task ProvidingNodeShouldSkipExpandingFromTop() + { + var query = @"Create table t1 (c1 int) + GO + Create table t2 (c1 int) + GO"; + await RunTest(databaseName, query, "testdb", async (testdbName, connectionString) => + { + serverInfo.DatabaseName = testdbName; + var oePath = ""; + var pathWithDb = string.Format(oePath, testdbName); + + var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options); + Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo"); + + nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes[0]); + Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1"); + + nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Tables")); + Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1"); + + nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "dbo.t1")); + Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns"); + + nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Columns")); + Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1"); + }); + } + private async Task RunTest(string databaseName, string query, string testDbPrefix, Func test) { SqlTestDb? testDb = null; diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Scripting/AsyncScriptAsScriptingOperationTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Scripting/AsyncScriptAsScriptingOperationTests.cs index dab9366f..6438083d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Scripting/AsyncScriptAsScriptingOperationTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Scripting/AsyncScriptAsScriptingOperationTests.cs @@ -105,7 +105,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Scripting var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, query, "ScriptingTests"); scriptingParams.ConnectionString = testDb.ConnectionString; - var actualScript = await AsyncScriptAsScriptingOperation.GetScriptAsScript(scriptingParams, null); + var actualScript = await AsyncScriptAsScriptingOperation.GetScriptAsScript(scriptingParams); foreach(var expectedStr in expectedScriptContents) {