mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-13 17:23:02 -05:00
Fixing connections for stateless object explorer to work with fabric tokens (#2252)
This commit is contained in:
@@ -12,6 +12,7 @@ using Microsoft.SqlTools.Utility;
|
||||
using Microsoft.SqlTools.Extensibility;
|
||||
using Microsoft.SqlTools.SqlCore.Utility;
|
||||
using System.IO;
|
||||
using Microsoft.SqlTools.SqlCore.Connection;
|
||||
|
||||
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
{
|
||||
@@ -26,7 +27,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
private SqlServerType sqlServerType;
|
||||
public ServerConnection serverConnection;
|
||||
|
||||
public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null)
|
||||
public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null, SecurityToken? accessToken = null)
|
||||
: base()
|
||||
{
|
||||
Validate.IsNotNull(nameof(ObjectExplorerServerInfo), serverInfo);
|
||||
@@ -36,9 +37,8 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
|
||||
var assembly = typeof(SqlCore.ObjectExplorer.SmoModel.SmoQuerier).Assembly;
|
||||
serviceProvider ??= ExtensionServiceProvider.CreateFromAssembliesInDirectory(Path.GetDirectoryName(assembly.Location), new string[] { Path.GetFileName(assembly.Location) });
|
||||
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag));
|
||||
this.serverConnection = serverConnection;
|
||||
|
||||
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag, accessToken));
|
||||
NodeValue = serverInfo.ServerName;
|
||||
IsAlwaysLeaf = false;
|
||||
NodeType = NodeTypes.Server.ToString();
|
||||
@@ -114,7 +114,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
|
||||
|
||||
|
||||
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null)
|
||||
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
|
||||
{
|
||||
string exceptionMessage;
|
||||
|
||||
@@ -123,7 +123,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
Server server = SmoWrapper.CreateServer(this.serverConnection);
|
||||
if (server != null)
|
||||
{
|
||||
return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag)
|
||||
return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag, token)
|
||||
{
|
||||
Parent = server,
|
||||
SqlServerType = this.sqlServerType
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
using System;
|
||||
using Microsoft.SqlServer.Management.Smo;
|
||||
using Microsoft.SqlTools.Extensibility;
|
||||
using Microsoft.SqlTools.SqlCore.Connection;
|
||||
using Microsoft.SqlTools.SqlCore.ObjectExplorer.Nodes;
|
||||
|
||||
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
@@ -20,23 +21,26 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
private SmoObjectBase parent;
|
||||
private SmoWrapper smoWrapper;
|
||||
private ValidForFlag validFor = 0;
|
||||
private Func<bool> groupBySchemaFlag;
|
||||
|
||||
/// <summary>
|
||||
/// Creates a context object with a server to use as the basis for any queries
|
||||
/// </summary>
|
||||
/// <param name="server"></param>
|
||||
public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null)
|
||||
: this(server, serviceProvider, null, groupBySchemaFlag)
|
||||
public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
|
||||
: this(server, serviceProvider, null, groupBySchemaFlag, token)
|
||||
{
|
||||
}
|
||||
|
||||
internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null)
|
||||
internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
|
||||
{
|
||||
this.server = server;
|
||||
ServiceProvider = serviceProvider;
|
||||
this.smoWrapper = serverManager ?? new SmoWrapper();
|
||||
this.groupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false);
|
||||
this.GroupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false);
|
||||
if(token != null && !string.IsNullOrEmpty(token.Token))
|
||||
{
|
||||
UpdateAccessToken(token.Token);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -85,9 +89,18 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Function that gets the group by schema
|
||||
/// </summary>
|
||||
/// <value></value>
|
||||
public Func<bool> GroupBySchemaFlag { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Returns group by schema flag value.
|
||||
/// </summary>
|
||||
public bool GroupBySchema
|
||||
{
|
||||
get => groupBySchemaFlag();
|
||||
get => GroupBySchemaFlag();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
@@ -114,7 +127,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||
/// <returns>new <see cref="SmoQueryContext"/> with all fields except <see cref="Parent"/> the same</returns>
|
||||
public SmoQueryContext CopyWithParent(SmoObjectBase parent)
|
||||
{
|
||||
SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.groupBySchemaFlag)
|
||||
SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.GroupBySchemaFlag)
|
||||
{
|
||||
database = this.Database,
|
||||
Parent = parent,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||
//
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
@@ -37,32 +38,16 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
||||
{
|
||||
using (SqlConnection conn = new SqlConnection(connectionString))
|
||||
{
|
||||
ServerConnection connection;
|
||||
if (accessToken != null)
|
||||
{
|
||||
connection = new ServerConnection(conn, accessToken as IRenewableToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
connection = new ServerConnection(conn);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
return await Expand(connection, accessToken, nodePath, serverInfo, options, filters);
|
||||
}
|
||||
finally
|
||||
{
|
||||
if (connection.IsOpen)
|
||||
{
|
||||
connection.Disconnect();
|
||||
}
|
||||
}
|
||||
conn.AccessToken = accessToken?.Token;
|
||||
conn.Open();
|
||||
ServerConnection connection = new ServerConnection(conn);
|
||||
connection.AccessToken = accessToken as IRenewableToken;
|
||||
return await Expand(connection, accessToken, nodePath, serverInfo, options, filters);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Expands the node at the given path and returns the child nodes. If parent is not null, it will skip expanding from the top and use the connection from the parent node.
|
||||
/// Expands the node at the given path and returns the child nodes.
|
||||
/// </summary>
|
||||
/// <param name="serverConnection"> Server connection to use for expanding the node. It will be used only if parent is null </param>
|
||||
/// <param name="accessToken"> Access token to connect to the server. To be used in case of AAD based connections </param>
|
||||
@@ -70,45 +55,31 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
||||
/// <param name="serverInfo"> Server information </param>
|
||||
/// <param name="options"> Object explorer expansion options </param>
|
||||
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
|
||||
/// <param name="parent"> Optional parent node. If provided, it will skip expanding from the top and and use the connection from the parent node </param>
|
||||
/// <returns></returns>
|
||||
/// </summary>
|
||||
|
||||
|
||||
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null, TreeNode parent = null)
|
||||
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null)
|
||||
{
|
||||
using (var taskCancellationTokenSource = new CancellationTokenSource())
|
||||
{
|
||||
|
||||
try
|
||||
{
|
||||
var token = accessToken == null ? null : accessToken.Token;
|
||||
|
||||
var task = Task.Run(() =>
|
||||
{
|
||||
TreeNode? node;
|
||||
if (parent == null)
|
||||
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter, accessToken);
|
||||
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
|
||||
if (nodePath == null || nodePath == string.Empty)
|
||||
{
|
||||
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter);
|
||||
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
|
||||
|
||||
if (nodePath == null || nodePath == string.Empty)
|
||||
{
|
||||
nodePath = rootNode.GetNodePath();
|
||||
}
|
||||
node = rootNode;
|
||||
if (node == null)
|
||||
{
|
||||
// Return empty array if node is not found
|
||||
return new TreeNode[0];
|
||||
}
|
||||
node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
|
||||
nodePath = rootNode.GetNodePath();
|
||||
}
|
||||
else
|
||||
node = rootNode;
|
||||
if (node == null)
|
||||
{
|
||||
node = parent;
|
||||
// Return empty array if node is not found
|
||||
return new TreeNode[0];
|
||||
}
|
||||
|
||||
node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
|
||||
if (node != null)
|
||||
{
|
||||
return node.Expand(taskCancellationTokenSource.Token, token, filters);
|
||||
@@ -118,20 +89,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
||||
throw new InvalidArgumentException($"Parent node not found for path {nodePath}");
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
if (await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) == task)
|
||||
{
|
||||
if (taskCancellationTokenSource.IsCancellationRequested)
|
||||
{
|
||||
throw new TimeoutException("The operation has timed out.");
|
||||
}
|
||||
return task.Result.ToArray();
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new TimeoutException("The operation has timed out.");
|
||||
}
|
||||
return await RunExpandTask(task, taskCancellationTokenSource, options.OperationTimeoutSeconds);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
@@ -140,6 +98,60 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Expands the given node and returns the child nodes.
|
||||
/// </summary>
|
||||
/// <param name="node"> Node to expand </param>
|
||||
/// <param name="options"> Object explorer expansion options </param>
|
||||
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
|
||||
/// <param name="securityToken"> Security token to connect to the server. To be used in case of AAD based connections </param>
|
||||
/// <returns></returns>
|
||||
public static async Task<TreeNode[]> ExpandTreeNode(TreeNode node, ObjectExplorerOptions options, INodeFilter[]? filters = null, SecurityToken? securityToken = null)
|
||||
{
|
||||
if (node == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(node));
|
||||
}
|
||||
|
||||
using (var taskCancellationTokenSource = new CancellationTokenSource())
|
||||
{
|
||||
var expandTask = Task.Run(async () =>
|
||||
{
|
||||
SmoQueryContext nodeContext = node.GetContextAs<SmoQueryContext>() ?? throw new ArgumentException("Node does not have a valid context");
|
||||
|
||||
if(options.GroupBySchemaFlagGetter != null)
|
||||
{
|
||||
nodeContext.GroupBySchemaFlag = options.GroupBySchemaFlagGetter;
|
||||
}
|
||||
if (!nodeContext.Server.ConnectionContext.IsOpen && securityToken != null)
|
||||
{
|
||||
var underlyingSqlConnection = nodeContext.Server.ConnectionContext.SqlConnectionObject;
|
||||
underlyingSqlConnection.AccessToken = securityToken.Token;
|
||||
await underlyingSqlConnection.OpenAsync();
|
||||
}
|
||||
|
||||
return node.Expand(taskCancellationTokenSource.Token, securityToken?.Token, filters);
|
||||
});
|
||||
|
||||
return await RunExpandTask(expandTask, taskCancellationTokenSource, options.OperationTimeoutSeconds);
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<TreeNode[]> RunExpandTask(Task<IList<TreeNode>> expansionTask, CancellationTokenSource taskCancellationTokenSource, int operationTimeoutSeconds)
|
||||
{
|
||||
if (await Task.WhenAny(expansionTask, Task.Delay(TimeSpan.FromSeconds(operationTimeoutSeconds))) == expansionTask)
|
||||
{
|
||||
if (taskCancellationTokenSource.IsCancellationRequested)
|
||||
{
|
||||
throw new TimeoutException("The operation has timed out.");
|
||||
}
|
||||
return expansionTask.Result.ToArray();
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new TimeoutException("The operation has timed out.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,16 +70,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer
|
||||
var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
|
||||
Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo");
|
||||
|
||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes[0]);
|
||||
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes[0], options, null, null);
|
||||
Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1");
|
||||
|
||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Tables"));
|
||||
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Tables"), options, null, null);
|
||||
Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1");
|
||||
|
||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "dbo.t1"));
|
||||
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "dbo.t1"), options, null, null);
|
||||
Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns");
|
||||
|
||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Columns"));
|
||||
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Columns"), options, null, null);
|
||||
Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1");
|
||||
});
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
|
||||
context = new Mock<SmoQueryContext>(new Server(), ExtensionServiceProvider.CreateDefaultServiceProvider(), () =>
|
||||
{
|
||||
return enableGroupBySchema;
|
||||
});
|
||||
}, null);
|
||||
context.CallBase = true;
|
||||
context.Object.ValidFor = ValidForFlag.None;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user