diff --git a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/ServerNode.cs b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/ServerNode.cs index f61d63af..b28e0ea5 100644 --- a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/ServerNode.cs +++ b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/ServerNode.cs @@ -12,6 +12,7 @@ using Microsoft.SqlTools.Utility; using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.SqlCore.Utility; using System.IO; +using Microsoft.SqlTools.SqlCore.Connection; namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel { @@ -26,7 +27,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel private SqlServerType sqlServerType; public ServerConnection serverConnection; - public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func groupBySchemaFlag = null) + public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func groupBySchemaFlag = null, SecurityToken? accessToken = null) : base() { Validate.IsNotNull(nameof(ObjectExplorerServerInfo), serverInfo); @@ -36,9 +37,8 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel var assembly = typeof(SqlCore.ObjectExplorer.SmoModel.SmoQuerier).Assembly; serviceProvider ??= ExtensionServiceProvider.CreateFromAssembliesInDirectory(Path.GetDirectoryName(assembly.Location), new string[] { Path.GetFileName(assembly.Location) }); - this.context = new Lazy(() => CreateContext(serviceProvider, groupBySchemaFlag)); this.serverConnection = serverConnection; - + this.context = new Lazy(() => CreateContext(serviceProvider, groupBySchemaFlag, accessToken)); NodeValue = serverInfo.ServerName; IsAlwaysLeaf = false; NodeType = NodeTypes.Server.ToString(); @@ -114,7 +114,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel - private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func groupBySchemaFlag = null) + private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func groupBySchemaFlag = null, SecurityToken token = null) { string exceptionMessage; @@ -123,7 +123,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel Server server = SmoWrapper.CreateServer(this.serverConnection); if (server != null) { - return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag) + return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag, token) { Parent = server, SqlServerType = this.sqlServerType diff --git a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/SmoQueryContext.cs b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/SmoQueryContext.cs index c4deb250..a170eecc 100644 --- a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/SmoQueryContext.cs +++ b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/SmoModel/SmoQueryContext.cs @@ -6,6 +6,7 @@ using System; using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.SqlCore.Connection; using Microsoft.SqlTools.SqlCore.ObjectExplorer.Nodes; namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel @@ -20,23 +21,26 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel private SmoObjectBase parent; private SmoWrapper smoWrapper; private ValidForFlag validFor = 0; - private Func groupBySchemaFlag; /// /// Creates a context object with a server to use as the basis for any queries /// /// - public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func groupBySchemaFlag = null) - : this(server, serviceProvider, null, groupBySchemaFlag) + public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func groupBySchemaFlag = null, SecurityToken token = null) + : this(server, serviceProvider, null, groupBySchemaFlag, token) { } - internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func groupBySchemaFlag = null) + internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func groupBySchemaFlag = null, SecurityToken token = null) { this.server = server; ServiceProvider = serviceProvider; this.smoWrapper = serverManager ?? new SmoWrapper(); - this.groupBySchemaFlag = groupBySchemaFlag ?? new Func(() => false); + this.GroupBySchemaFlag = groupBySchemaFlag ?? new Func(() => false); + if(token != null && !string.IsNullOrEmpty(token.Token)) + { + UpdateAccessToken(token.Token); + } } /// @@ -85,9 +89,18 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel } } + /// + /// Function that gets the group by schema + /// + /// + public Func GroupBySchemaFlag { get; set; } + + /// + /// Returns group by schema flag value. + /// public bool GroupBySchema { - get => groupBySchemaFlag(); + get => GroupBySchemaFlag(); } /// @@ -114,7 +127,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel /// new with all fields except the same public SmoQueryContext CopyWithParent(SmoObjectBase parent) { - SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.groupBySchemaFlag) + SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.GroupBySchemaFlag) { database = this.Database, Parent = parent, diff --git a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs index daa426a5..df24f7a9 100644 --- a/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs +++ b/src/Microsoft.SqlTools.SqlCore/ObjectExplorer/StatelessObjectExplorer.cs @@ -3,6 +3,7 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // using System; +using System.Collections.Generic; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -37,32 +38,16 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer { using (SqlConnection conn = new SqlConnection(connectionString)) { - ServerConnection connection; - if (accessToken != null) - { - connection = new ServerConnection(conn, accessToken as IRenewableToken); - } - else - { - connection = new ServerConnection(conn); - } - - try - { - return await Expand(connection, accessToken, nodePath, serverInfo, options, filters); - } - finally - { - if (connection.IsOpen) - { - connection.Disconnect(); - } - } + conn.AccessToken = accessToken?.Token; + conn.Open(); + ServerConnection connection = new ServerConnection(conn); + connection.AccessToken = accessToken as IRenewableToken; + return await Expand(connection, accessToken, nodePath, serverInfo, options, filters); } } /// - /// Expands the node at the given path and returns the child nodes. If parent is not null, it will skip expanding from the top and use the connection from the parent node. + /// Expands the node at the given path and returns the child nodes. /// /// Server connection to use for expanding the node. It will be used only if parent is null /// Access token to connect to the server. To be used in case of AAD based connections @@ -70,45 +55,31 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer /// Server information /// Object explorer expansion options /// Filters to be applied on the leaf nodes - /// Optional parent node. If provided, it will skip expanding from the top and and use the connection from the parent node /// /// - - - public static async Task Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null, TreeNode parent = null) + public static async Task Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null) { using (var taskCancellationTokenSource = new CancellationTokenSource()) { - try { var token = accessToken == null ? null : accessToken.Token; - var task = Task.Run(() => { TreeNode? node; - if (parent == null) + ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter, accessToken); + TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName); + if (nodePath == null || nodePath == string.Empty) { - ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter); - TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName); - - if (nodePath == null || nodePath == string.Empty) - { - nodePath = rootNode.GetNodePath(); - } - node = rootNode; - if (node == null) - { - // Return empty array if node is not found - return new TreeNode[0]; - } - node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token); + nodePath = rootNode.GetNodePath(); } - else + node = rootNode; + if (node == null) { - node = parent; + // Return empty array if node is not found + return new TreeNode[0]; } - + node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token); if (node != null) { return node.Expand(taskCancellationTokenSource.Token, token, filters); @@ -118,20 +89,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer throw new InvalidArgumentException($"Parent node not found for path {nodePath}"); } }); - - - if (await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) == task) - { - if (taskCancellationTokenSource.IsCancellationRequested) - { - throw new TimeoutException("The operation has timed out."); - } - return task.Result.ToArray(); - } - else - { - throw new TimeoutException("The operation has timed out."); - } + return await RunExpandTask(task, taskCancellationTokenSource, options.OperationTimeoutSeconds); } catch (Exception ex) { @@ -140,6 +98,60 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer } } + + /// + /// Expands the given node and returns the child nodes. + /// + /// Node to expand + /// Object explorer expansion options + /// Filters to be applied on the leaf nodes + /// Security token to connect to the server. To be used in case of AAD based connections + /// + public static async Task ExpandTreeNode(TreeNode node, ObjectExplorerOptions options, INodeFilter[]? filters = null, SecurityToken? securityToken = null) + { + if (node == null) + { + throw new ArgumentNullException(nameof(node)); + } + using (var taskCancellationTokenSource = new CancellationTokenSource()) + { + var expandTask = Task.Run(async () => + { + SmoQueryContext nodeContext = node.GetContextAs() ?? throw new ArgumentException("Node does not have a valid context"); + + if(options.GroupBySchemaFlagGetter != null) + { + nodeContext.GroupBySchemaFlag = options.GroupBySchemaFlagGetter; + } + if (!nodeContext.Server.ConnectionContext.IsOpen && securityToken != null) + { + var underlyingSqlConnection = nodeContext.Server.ConnectionContext.SqlConnectionObject; + underlyingSqlConnection.AccessToken = securityToken.Token; + await underlyingSqlConnection.OpenAsync(); + } + + return node.Expand(taskCancellationTokenSource.Token, securityToken?.Token, filters); + }); + + return await RunExpandTask(expandTask, taskCancellationTokenSource, options.OperationTimeoutSeconds); + } + } + + private static async Task RunExpandTask(Task> expansionTask, CancellationTokenSource taskCancellationTokenSource, int operationTimeoutSeconds) + { + if (await Task.WhenAny(expansionTask, Task.Delay(TimeSpan.FromSeconds(operationTimeoutSeconds))) == expansionTask) + { + if (taskCancellationTokenSource.IsCancellationRequested) + { + throw new TimeoutException("The operation has timed out."); + } + return expansionTask.Result.ToArray(); + } + else + { + throw new TimeoutException("The operation has timed out."); + } + } } } \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs index 85903372..05782777 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/StatelessObjectExplorerServiceTests.cs @@ -70,16 +70,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options); Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo"); - nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes[0]); + nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes[0], options, null, null); Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1"); - nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Tables")); + nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Tables"), options, null, null); Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1"); - nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "dbo.t1")); + nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "dbo.t1"), options, null, null); Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns"); - nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Columns")); + nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Columns"), options, null, null); Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1"); }); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/GroupBySchemaTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/GroupBySchemaTests.cs index 223771e3..690a645d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/GroupBySchemaTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/GroupBySchemaTests.cs @@ -44,7 +44,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer context = new Mock(new Server(), ExtensionServiceProvider.CreateDefaultServiceProvider(), () => { return enableGroupBySchema; - }); + }, null); context.CallBase = true; context.Object.ValidFor = ValidForFlag.None;