Fixing connections for stateless object explorer to work with fabric tokens (#2252)

This commit is contained in:
Aasim Khan
2023-09-27 02:07:36 +00:00
committed by GitHub
parent 78581e8508
commit 138a70efa0
5 changed files with 102 additions and 77 deletions

View File

@@ -12,6 +12,7 @@ using Microsoft.SqlTools.Utility;
using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.SqlCore.Utility; using Microsoft.SqlTools.SqlCore.Utility;
using System.IO; using System.IO;
using Microsoft.SqlTools.SqlCore.Connection;
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
{ {
@@ -26,7 +27,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
private SqlServerType sqlServerType; private SqlServerType sqlServerType;
public ServerConnection serverConnection; public ServerConnection serverConnection;
public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null) public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null, SecurityToken? accessToken = null)
: base() : base()
{ {
Validate.IsNotNull(nameof(ObjectExplorerServerInfo), serverInfo); Validate.IsNotNull(nameof(ObjectExplorerServerInfo), serverInfo);
@@ -36,9 +37,8 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
var assembly = typeof(SqlCore.ObjectExplorer.SmoModel.SmoQuerier).Assembly; var assembly = typeof(SqlCore.ObjectExplorer.SmoModel.SmoQuerier).Assembly;
serviceProvider ??= ExtensionServiceProvider.CreateFromAssembliesInDirectory(Path.GetDirectoryName(assembly.Location), new string[] { Path.GetFileName(assembly.Location) }); serviceProvider ??= ExtensionServiceProvider.CreateFromAssembliesInDirectory(Path.GetDirectoryName(assembly.Location), new string[] { Path.GetFileName(assembly.Location) });
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag));
this.serverConnection = serverConnection; this.serverConnection = serverConnection;
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag, accessToken));
NodeValue = serverInfo.ServerName; NodeValue = serverInfo.ServerName;
IsAlwaysLeaf = false; IsAlwaysLeaf = false;
NodeType = NodeTypes.Server.ToString(); NodeType = NodeTypes.Server.ToString();
@@ -114,7 +114,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null) private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
{ {
string exceptionMessage; string exceptionMessage;
@@ -123,7 +123,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
Server server = SmoWrapper.CreateServer(this.serverConnection); Server server = SmoWrapper.CreateServer(this.serverConnection);
if (server != null) if (server != null)
{ {
return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag) return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag, token)
{ {
Parent = server, Parent = server,
SqlServerType = this.sqlServerType SqlServerType = this.sqlServerType

View File

@@ -6,6 +6,7 @@
using System; using System;
using Microsoft.SqlServer.Management.Smo; using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.SqlCore.Connection;
using Microsoft.SqlTools.SqlCore.ObjectExplorer.Nodes; using Microsoft.SqlTools.SqlCore.ObjectExplorer.Nodes;
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
@@ -20,23 +21,26 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
private SmoObjectBase parent; private SmoObjectBase parent;
private SmoWrapper smoWrapper; private SmoWrapper smoWrapper;
private ValidForFlag validFor = 0; private ValidForFlag validFor = 0;
private Func<bool> groupBySchemaFlag;
/// <summary> /// <summary>
/// Creates a context object with a server to use as the basis for any queries /// Creates a context object with a server to use as the basis for any queries
/// </summary> /// </summary>
/// <param name="server"></param> /// <param name="server"></param>
public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null) public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
: this(server, serviceProvider, null, groupBySchemaFlag) : this(server, serviceProvider, null, groupBySchemaFlag, token)
{ {
} }
internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null) internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
{ {
this.server = server; this.server = server;
ServiceProvider = serviceProvider; ServiceProvider = serviceProvider;
this.smoWrapper = serverManager ?? new SmoWrapper(); this.smoWrapper = serverManager ?? new SmoWrapper();
this.groupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false); this.GroupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false);
if(token != null && !string.IsNullOrEmpty(token.Token))
{
UpdateAccessToken(token.Token);
}
} }
/// <summary> /// <summary>
@@ -85,9 +89,18 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
} }
} }
/// <summary>
/// Function that gets the group by schema
/// </summary>
/// <value></value>
public Func<bool> GroupBySchemaFlag { get; set; }
/// <summary>
/// Returns group by schema flag value.
/// </summary>
public bool GroupBySchema public bool GroupBySchema
{ {
get => groupBySchemaFlag(); get => GroupBySchemaFlag();
} }
/// <summary> /// <summary>
@@ -114,7 +127,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
/// <returns>new <see cref="SmoQueryContext"/> with all fields except <see cref="Parent"/> the same</returns> /// <returns>new <see cref="SmoQueryContext"/> with all fields except <see cref="Parent"/> the same</returns>
public SmoQueryContext CopyWithParent(SmoObjectBase parent) public SmoQueryContext CopyWithParent(SmoObjectBase parent)
{ {
SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.groupBySchemaFlag) SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.GroupBySchemaFlag)
{ {
database = this.Database, database = this.Database,
Parent = parent, Parent = parent,

View File

@@ -3,6 +3,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information. // Licensed under the MIT license. See LICENSE file in the project root for full license information.
// //
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -37,32 +38,16 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
{ {
using (SqlConnection conn = new SqlConnection(connectionString)) using (SqlConnection conn = new SqlConnection(connectionString))
{ {
ServerConnection connection; conn.AccessToken = accessToken?.Token;
if (accessToken != null) conn.Open();
{ ServerConnection connection = new ServerConnection(conn);
connection = new ServerConnection(conn, accessToken as IRenewableToken); connection.AccessToken = accessToken as IRenewableToken;
}
else
{
connection = new ServerConnection(conn);
}
try
{
return await Expand(connection, accessToken, nodePath, serverInfo, options, filters); return await Expand(connection, accessToken, nodePath, serverInfo, options, filters);
} }
finally
{
if (connection.IsOpen)
{
connection.Disconnect();
}
}
}
} }
/// <summary> /// <summary>
/// Expands the node at the given path and returns the child nodes. If parent is not null, it will skip expanding from the top and use the connection from the parent node. /// Expands the node at the given path and returns the child nodes.
/// </summary> /// </summary>
/// <param name="serverConnection"> Server connection to use for expanding the node. It will be used only if parent is null </param> /// <param name="serverConnection"> Server connection to use for expanding the node. It will be used only if parent is null </param>
/// <param name="accessToken"> Access token to connect to the server. To be used in case of AAD based connections </param> /// <param name="accessToken"> Access token to connect to the server. To be used in case of AAD based connections </param>
@@ -70,28 +55,20 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
/// <param name="serverInfo"> Server information </param> /// <param name="serverInfo"> Server information </param>
/// <param name="options"> Object explorer expansion options </param> /// <param name="options"> Object explorer expansion options </param>
/// <param name="filters"> Filters to be applied on the leaf nodes </param> /// <param name="filters"> Filters to be applied on the leaf nodes </param>
/// <param name="parent"> Optional parent node. If provided, it will skip expanding from the top and and use the connection from the parent node </param>
/// <returns></returns> /// <returns></returns>
/// </summary> /// </summary>
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null)
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null, TreeNode parent = null)
{ {
using (var taskCancellationTokenSource = new CancellationTokenSource()) using (var taskCancellationTokenSource = new CancellationTokenSource())
{ {
try try
{ {
var token = accessToken == null ? null : accessToken.Token; var token = accessToken == null ? null : accessToken.Token;
var task = Task.Run(() => var task = Task.Run(() =>
{ {
TreeNode? node; TreeNode? node;
if (parent == null) ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter, accessToken);
{
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter);
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName); TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
if (nodePath == null || nodePath == string.Empty) if (nodePath == null || nodePath == string.Empty)
{ {
nodePath = rootNode.GetNodePath(); nodePath = rootNode.GetNodePath();
@@ -103,12 +80,6 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
return new TreeNode[0]; return new TreeNode[0];
} }
node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token); node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
}
else
{
node = parent;
}
if (node != null) if (node != null)
{ {
return node.Expand(taskCancellationTokenSource.Token, token, filters); return node.Expand(taskCancellationTokenSource.Token, token, filters);
@@ -118,20 +89,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
throw new InvalidArgumentException($"Parent node not found for path {nodePath}"); throw new InvalidArgumentException($"Parent node not found for path {nodePath}");
} }
}); });
return await RunExpandTask(task, taskCancellationTokenSource, options.OperationTimeoutSeconds);
if (await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) == task)
{
if (taskCancellationTokenSource.IsCancellationRequested)
{
throw new TimeoutException("The operation has timed out.");
}
return task.Result.ToArray();
}
else
{
throw new TimeoutException("The operation has timed out.");
}
} }
catch (Exception ex) catch (Exception ex)
{ {
@@ -141,5 +99,59 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
} }
} }
/// <summary>
/// Expands the given node and returns the child nodes.
/// </summary>
/// <param name="node"> Node to expand </param>
/// <param name="options"> Object explorer expansion options </param>
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
/// <param name="securityToken"> Security token to connect to the server. To be used in case of AAD based connections </param>
/// <returns></returns>
public static async Task<TreeNode[]> ExpandTreeNode(TreeNode node, ObjectExplorerOptions options, INodeFilter[]? filters = null, SecurityToken? securityToken = null)
{
if (node == null)
{
throw new ArgumentNullException(nameof(node));
}
using (var taskCancellationTokenSource = new CancellationTokenSource())
{
var expandTask = Task.Run(async () =>
{
SmoQueryContext nodeContext = node.GetContextAs<SmoQueryContext>() ?? throw new ArgumentException("Node does not have a valid context");
if(options.GroupBySchemaFlagGetter != null)
{
nodeContext.GroupBySchemaFlag = options.GroupBySchemaFlagGetter;
}
if (!nodeContext.Server.ConnectionContext.IsOpen && securityToken != null)
{
var underlyingSqlConnection = nodeContext.Server.ConnectionContext.SqlConnectionObject;
underlyingSqlConnection.AccessToken = securityToken.Token;
await underlyingSqlConnection.OpenAsync();
}
return node.Expand(taskCancellationTokenSource.Token, securityToken?.Token, filters);
});
return await RunExpandTask(expandTask, taskCancellationTokenSource, options.OperationTimeoutSeconds);
}
}
private static async Task<TreeNode[]> RunExpandTask(Task<IList<TreeNode>> expansionTask, CancellationTokenSource taskCancellationTokenSource, int operationTimeoutSeconds)
{
if (await Task.WhenAny(expansionTask, Task.Delay(TimeSpan.FromSeconds(operationTimeoutSeconds))) == expansionTask)
{
if (taskCancellationTokenSource.IsCancellationRequested)
{
throw new TimeoutException("The operation has timed out.");
}
return expansionTask.Result.ToArray();
}
else
{
throw new TimeoutException("The operation has timed out.");
}
}
} }
} }

View File

@@ -70,16 +70,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer
var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options); var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo"); Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes[0]); nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes[0], options, null, null);
Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1"); Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Tables")); nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Tables"), options, null, null);
Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1"); Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "dbo.t1")); nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "dbo.t1"), options, null, null);
Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns"); Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Columns")); nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Columns"), options, null, null);
Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1"); Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1");
}); });
} }

View File

@@ -44,7 +44,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
context = new Mock<SmoQueryContext>(new Server(), ExtensionServiceProvider.CreateDefaultServiceProvider(), () => context = new Mock<SmoQueryContext>(new Server(), ExtensionServiceProvider.CreateDefaultServiceProvider(), () =>
{ {
return enableGroupBySchema; return enableGroupBySchema;
}); }, null);
context.CallBase = true; context.CallBase = true;
context.Object.ValidFor = ValidForFlag.None; context.Object.ValidFor = ValidForFlag.None;