Fixing connections for stateless object explorer to work with fabric tokens (#2252)

This commit is contained in:
Aasim Khan
2023-09-27 02:07:36 +00:00
committed by GitHub
parent 78581e8508
commit 138a70efa0
5 changed files with 102 additions and 77 deletions

View File

@@ -12,6 +12,7 @@ using Microsoft.SqlTools.Utility;
using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.SqlCore.Utility;
using System.IO;
using Microsoft.SqlTools.SqlCore.Connection;
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
{
@@ -26,7 +27,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
private SqlServerType sqlServerType;
public ServerConnection serverConnection;
public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null)
public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null, SecurityToken? accessToken = null)
: base()
{
Validate.IsNotNull(nameof(ObjectExplorerServerInfo), serverInfo);
@@ -36,9 +37,8 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
var assembly = typeof(SqlCore.ObjectExplorer.SmoModel.SmoQuerier).Assembly;
serviceProvider ??= ExtensionServiceProvider.CreateFromAssembliesInDirectory(Path.GetDirectoryName(assembly.Location), new string[] { Path.GetFileName(assembly.Location) });
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag));
this.serverConnection = serverConnection;
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag, accessToken));
NodeValue = serverInfo.ServerName;
IsAlwaysLeaf = false;
NodeType = NodeTypes.Server.ToString();
@@ -114,7 +114,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null)
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
{
string exceptionMessage;
@@ -123,7 +123,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
Server server = SmoWrapper.CreateServer(this.serverConnection);
if (server != null)
{
return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag)
return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag, token)
{
Parent = server,
SqlServerType = this.sqlServerType

View File

@@ -6,6 +6,7 @@
using System;
using Microsoft.SqlServer.Management.Smo;
using Microsoft.SqlTools.Extensibility;
using Microsoft.SqlTools.SqlCore.Connection;
using Microsoft.SqlTools.SqlCore.ObjectExplorer.Nodes;
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
@@ -20,23 +21,26 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
private SmoObjectBase parent;
private SmoWrapper smoWrapper;
private ValidForFlag validFor = 0;
private Func<bool> groupBySchemaFlag;
/// <summary>
/// Creates a context object with a server to use as the basis for any queries
/// </summary>
/// <param name="server"></param>
public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null)
: this(server, serviceProvider, null, groupBySchemaFlag)
public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
: this(server, serviceProvider, null, groupBySchemaFlag, token)
{
}
internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null)
internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
{
this.server = server;
ServiceProvider = serviceProvider;
this.smoWrapper = serverManager ?? new SmoWrapper();
this.groupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false);
this.GroupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false);
if(token != null && !string.IsNullOrEmpty(token.Token))
{
UpdateAccessToken(token.Token);
}
}
/// <summary>
@@ -85,9 +89,18 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
}
}
/// <summary>
/// Function that gets the group by schema
/// </summary>
/// <value></value>
public Func<bool> GroupBySchemaFlag { get; set; }
/// <summary>
/// Returns group by schema flag value.
/// </summary>
public bool GroupBySchema
{
get => groupBySchemaFlag();
get => GroupBySchemaFlag();
}
/// <summary>
@@ -114,7 +127,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
/// <returns>new <see cref="SmoQueryContext"/> with all fields except <see cref="Parent"/> the same</returns>
public SmoQueryContext CopyWithParent(SmoObjectBase parent)
{
SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.groupBySchemaFlag)
SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.GroupBySchemaFlag)
{
database = this.Database,
Parent = parent,

View File

@@ -3,6 +3,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@@ -37,32 +38,16 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
{
using (SqlConnection conn = new SqlConnection(connectionString))
{
ServerConnection connection;
if (accessToken != null)
{
connection = new ServerConnection(conn, accessToken as IRenewableToken);
}
else
{
connection = new ServerConnection(conn);
}
try
{
conn.AccessToken = accessToken?.Token;
conn.Open();
ServerConnection connection = new ServerConnection(conn);
connection.AccessToken = accessToken as IRenewableToken;
return await Expand(connection, accessToken, nodePath, serverInfo, options, filters);
}
finally
{
if (connection.IsOpen)
{
connection.Disconnect();
}
}
}
}
/// <summary>
/// Expands the node at the given path and returns the child nodes. If parent is not null, it will skip expanding from the top and use the connection from the parent node.
/// Expands the node at the given path and returns the child nodes.
/// </summary>
/// <param name="serverConnection"> Server connection to use for expanding the node. It will be used only if parent is null </param>
/// <param name="accessToken"> Access token to connect to the server. To be used in case of AAD based connections </param>
@@ -70,28 +55,20 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
/// <param name="serverInfo"> Server information </param>
/// <param name="options"> Object explorer expansion options </param>
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
/// <param name="parent"> Optional parent node. If provided, it will skip expanding from the top and and use the connection from the parent node </param>
/// <returns></returns>
/// </summary>
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null, TreeNode parent = null)
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null)
{
using (var taskCancellationTokenSource = new CancellationTokenSource())
{
try
{
var token = accessToken == null ? null : accessToken.Token;
var task = Task.Run(() =>
{
TreeNode? node;
if (parent == null)
{
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter);
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter, accessToken);
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
if (nodePath == null || nodePath == string.Empty)
{
nodePath = rootNode.GetNodePath();
@@ -103,12 +80,6 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
return new TreeNode[0];
}
node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
}
else
{
node = parent;
}
if (node != null)
{
return node.Expand(taskCancellationTokenSource.Token, token, filters);
@@ -118,20 +89,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
throw new InvalidArgumentException($"Parent node not found for path {nodePath}");
}
});
if (await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) == task)
{
if (taskCancellationTokenSource.IsCancellationRequested)
{
throw new TimeoutException("The operation has timed out.");
}
return task.Result.ToArray();
}
else
{
throw new TimeoutException("The operation has timed out.");
}
return await RunExpandTask(task, taskCancellationTokenSource, options.OperationTimeoutSeconds);
}
catch (Exception ex)
{
@@ -141,5 +99,59 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
}
}
/// <summary>
/// Expands the given node and returns the child nodes.
/// </summary>
/// <param name="node"> Node to expand </param>
/// <param name="options"> Object explorer expansion options </param>
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
/// <param name="securityToken"> Security token to connect to the server. To be used in case of AAD based connections </param>
/// <returns></returns>
public static async Task<TreeNode[]> ExpandTreeNode(TreeNode node, ObjectExplorerOptions options, INodeFilter[]? filters = null, SecurityToken? securityToken = null)
{
if (node == null)
{
throw new ArgumentNullException(nameof(node));
}
using (var taskCancellationTokenSource = new CancellationTokenSource())
{
var expandTask = Task.Run(async () =>
{
SmoQueryContext nodeContext = node.GetContextAs<SmoQueryContext>() ?? throw new ArgumentException("Node does not have a valid context");
if(options.GroupBySchemaFlagGetter != null)
{
nodeContext.GroupBySchemaFlag = options.GroupBySchemaFlagGetter;
}
if (!nodeContext.Server.ConnectionContext.IsOpen && securityToken != null)
{
var underlyingSqlConnection = nodeContext.Server.ConnectionContext.SqlConnectionObject;
underlyingSqlConnection.AccessToken = securityToken.Token;
await underlyingSqlConnection.OpenAsync();
}
return node.Expand(taskCancellationTokenSource.Token, securityToken?.Token, filters);
});
return await RunExpandTask(expandTask, taskCancellationTokenSource, options.OperationTimeoutSeconds);
}
}
private static async Task<TreeNode[]> RunExpandTask(Task<IList<TreeNode>> expansionTask, CancellationTokenSource taskCancellationTokenSource, int operationTimeoutSeconds)
{
if (await Task.WhenAny(expansionTask, Task.Delay(TimeSpan.FromSeconds(operationTimeoutSeconds))) == expansionTask)
{
if (taskCancellationTokenSource.IsCancellationRequested)
{
throw new TimeoutException("The operation has timed out.");
}
return expansionTask.Result.ToArray();
}
else
{
throw new TimeoutException("The operation has timed out.");
}
}
}
}

View File

@@ -70,16 +70,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer
var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes[0]);
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes[0], options, null, null);
Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Tables"));
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Tables"), options, null, null);
Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "dbo.t1"));
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "dbo.t1"), options, null, null);
Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns");
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Columns"));
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Columns"), options, null, null);
Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1");
});
}

View File

@@ -44,7 +44,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
context = new Mock<SmoQueryContext>(new Server(), ExtensionServiceProvider.CreateDefaultServiceProvider(), () =>
{
return enableGroupBySchema;
});
}, null);
context.CallBase = true;
context.Object.ValidFor = ValidForFlag.None;