mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-01-13 17:23:02 -05:00
Fixing connections for stateless object explorer to work with fabric tokens (#2252)
This commit is contained in:
@@ -12,6 +12,7 @@ using Microsoft.SqlTools.Utility;
|
|||||||
using Microsoft.SqlTools.Extensibility;
|
using Microsoft.SqlTools.Extensibility;
|
||||||
using Microsoft.SqlTools.SqlCore.Utility;
|
using Microsoft.SqlTools.SqlCore.Utility;
|
||||||
using System.IO;
|
using System.IO;
|
||||||
|
using Microsoft.SqlTools.SqlCore.Connection;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||||
{
|
{
|
||||||
@@ -26,7 +27,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
|||||||
private SqlServerType sqlServerType;
|
private SqlServerType sqlServerType;
|
||||||
public ServerConnection serverConnection;
|
public ServerConnection serverConnection;
|
||||||
|
|
||||||
public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null)
|
public ServerNode(ObjectExplorerServerInfo serverInfo, ServerConnection serverConnection, IMultiServiceProvider serviceProvider = null, Func<bool> groupBySchemaFlag = null, SecurityToken? accessToken = null)
|
||||||
: base()
|
: base()
|
||||||
{
|
{
|
||||||
Validate.IsNotNull(nameof(ObjectExplorerServerInfo), serverInfo);
|
Validate.IsNotNull(nameof(ObjectExplorerServerInfo), serverInfo);
|
||||||
@@ -36,9 +37,8 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
|||||||
|
|
||||||
var assembly = typeof(SqlCore.ObjectExplorer.SmoModel.SmoQuerier).Assembly;
|
var assembly = typeof(SqlCore.ObjectExplorer.SmoModel.SmoQuerier).Assembly;
|
||||||
serviceProvider ??= ExtensionServiceProvider.CreateFromAssembliesInDirectory(Path.GetDirectoryName(assembly.Location), new string[] { Path.GetFileName(assembly.Location) });
|
serviceProvider ??= ExtensionServiceProvider.CreateFromAssembliesInDirectory(Path.GetDirectoryName(assembly.Location), new string[] { Path.GetFileName(assembly.Location) });
|
||||||
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag));
|
|
||||||
this.serverConnection = serverConnection;
|
this.serverConnection = serverConnection;
|
||||||
|
this.context = new Lazy<SmoQueryContext>(() => CreateContext(serviceProvider, groupBySchemaFlag, accessToken));
|
||||||
NodeValue = serverInfo.ServerName;
|
NodeValue = serverInfo.ServerName;
|
||||||
IsAlwaysLeaf = false;
|
IsAlwaysLeaf = false;
|
||||||
NodeType = NodeTypes.Server.ToString();
|
NodeType = NodeTypes.Server.ToString();
|
||||||
@@ -114,7 +114,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null)
|
private SmoQueryContext CreateContext(IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
|
||||||
{
|
{
|
||||||
string exceptionMessage;
|
string exceptionMessage;
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
|||||||
Server server = SmoWrapper.CreateServer(this.serverConnection);
|
Server server = SmoWrapper.CreateServer(this.serverConnection);
|
||||||
if (server != null)
|
if (server != null)
|
||||||
{
|
{
|
||||||
return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag)
|
return new SmoQueryContext(server, serviceProvider, SmoWrapper, groupBySchemaFlag, token)
|
||||||
{
|
{
|
||||||
Parent = server,
|
Parent = server,
|
||||||
SqlServerType = this.sqlServerType
|
SqlServerType = this.sqlServerType
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
using System;
|
using System;
|
||||||
using Microsoft.SqlServer.Management.Smo;
|
using Microsoft.SqlServer.Management.Smo;
|
||||||
using Microsoft.SqlTools.Extensibility;
|
using Microsoft.SqlTools.Extensibility;
|
||||||
|
using Microsoft.SqlTools.SqlCore.Connection;
|
||||||
using Microsoft.SqlTools.SqlCore.ObjectExplorer.Nodes;
|
using Microsoft.SqlTools.SqlCore.ObjectExplorer.Nodes;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
||||||
@@ -20,23 +21,26 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
|||||||
private SmoObjectBase parent;
|
private SmoObjectBase parent;
|
||||||
private SmoWrapper smoWrapper;
|
private SmoWrapper smoWrapper;
|
||||||
private ValidForFlag validFor = 0;
|
private ValidForFlag validFor = 0;
|
||||||
private Func<bool> groupBySchemaFlag;
|
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Creates a context object with a server to use as the basis for any queries
|
/// Creates a context object with a server to use as the basis for any queries
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="server"></param>
|
/// <param name="server"></param>
|
||||||
public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null)
|
public SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
|
||||||
: this(server, serviceProvider, null, groupBySchemaFlag)
|
: this(server, serviceProvider, null, groupBySchemaFlag, token)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null)
|
internal SmoQueryContext(Server server, IMultiServiceProvider serviceProvider, SmoWrapper serverManager, Func<bool> groupBySchemaFlag = null, SecurityToken token = null)
|
||||||
{
|
{
|
||||||
this.server = server;
|
this.server = server;
|
||||||
ServiceProvider = serviceProvider;
|
ServiceProvider = serviceProvider;
|
||||||
this.smoWrapper = serverManager ?? new SmoWrapper();
|
this.smoWrapper = serverManager ?? new SmoWrapper();
|
||||||
this.groupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false);
|
this.GroupBySchemaFlag = groupBySchemaFlag ?? new Func<bool>(() => false);
|
||||||
|
if(token != null && !string.IsNullOrEmpty(token.Token))
|
||||||
|
{
|
||||||
|
UpdateAccessToken(token.Token);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -85,9 +89,18 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Function that gets the group by schema
|
||||||
|
/// </summary>
|
||||||
|
/// <value></value>
|
||||||
|
public Func<bool> GroupBySchemaFlag { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Returns group by schema flag value.
|
||||||
|
/// </summary>
|
||||||
public bool GroupBySchema
|
public bool GroupBySchema
|
||||||
{
|
{
|
||||||
get => groupBySchemaFlag();
|
get => GroupBySchemaFlag();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -114,7 +127,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer.SmoModel
|
|||||||
/// <returns>new <see cref="SmoQueryContext"/> with all fields except <see cref="Parent"/> the same</returns>
|
/// <returns>new <see cref="SmoQueryContext"/> with all fields except <see cref="Parent"/> the same</returns>
|
||||||
public SmoQueryContext CopyWithParent(SmoObjectBase parent)
|
public SmoQueryContext CopyWithParent(SmoObjectBase parent)
|
||||||
{
|
{
|
||||||
SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.groupBySchemaFlag)
|
SmoQueryContext context = new SmoQueryContext(this.Server, this.ServiceProvider, this.smoWrapper, this.GroupBySchemaFlag)
|
||||||
{
|
{
|
||||||
database = this.Database,
|
database = this.Database,
|
||||||
Parent = parent,
|
Parent = parent,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
//
|
//
|
||||||
using System;
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
using System.Linq;
|
using System.Linq;
|
||||||
using System.Threading;
|
using System.Threading;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
@@ -37,32 +38,16 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
|||||||
{
|
{
|
||||||
using (SqlConnection conn = new SqlConnection(connectionString))
|
using (SqlConnection conn = new SqlConnection(connectionString))
|
||||||
{
|
{
|
||||||
ServerConnection connection;
|
conn.AccessToken = accessToken?.Token;
|
||||||
if (accessToken != null)
|
conn.Open();
|
||||||
{
|
ServerConnection connection = new ServerConnection(conn);
|
||||||
connection = new ServerConnection(conn, accessToken as IRenewableToken);
|
connection.AccessToken = accessToken as IRenewableToken;
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
connection = new ServerConnection(conn);
|
|
||||||
}
|
|
||||||
|
|
||||||
try
|
|
||||||
{
|
|
||||||
return await Expand(connection, accessToken, nodePath, serverInfo, options, filters);
|
return await Expand(connection, accessToken, nodePath, serverInfo, options, filters);
|
||||||
}
|
}
|
||||||
finally
|
|
||||||
{
|
|
||||||
if (connection.IsOpen)
|
|
||||||
{
|
|
||||||
connection.Disconnect();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Expands the node at the given path and returns the child nodes. If parent is not null, it will skip expanding from the top and use the connection from the parent node.
|
/// Expands the node at the given path and returns the child nodes.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="serverConnection"> Server connection to use for expanding the node. It will be used only if parent is null </param>
|
/// <param name="serverConnection"> Server connection to use for expanding the node. It will be used only if parent is null </param>
|
||||||
/// <param name="accessToken"> Access token to connect to the server. To be used in case of AAD based connections </param>
|
/// <param name="accessToken"> Access token to connect to the server. To be used in case of AAD based connections </param>
|
||||||
@@ -70,28 +55,20 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
|||||||
/// <param name="serverInfo"> Server information </param>
|
/// <param name="serverInfo"> Server information </param>
|
||||||
/// <param name="options"> Object explorer expansion options </param>
|
/// <param name="options"> Object explorer expansion options </param>
|
||||||
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
|
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
|
||||||
/// <param name="parent"> Optional parent node. If provided, it will skip expanding from the top and and use the connection from the parent node </param>
|
|
||||||
/// <returns></returns>
|
/// <returns></returns>
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null)
|
||||||
|
|
||||||
public static async Task<TreeNode[]> Expand(ServerConnection serverConnection, SecurityToken? accessToken, string? nodePath, ObjectExplorerServerInfo serverInfo, ObjectExplorerOptions options, INodeFilter[]? filters = null, TreeNode parent = null)
|
|
||||||
{
|
{
|
||||||
using (var taskCancellationTokenSource = new CancellationTokenSource())
|
using (var taskCancellationTokenSource = new CancellationTokenSource())
|
||||||
{
|
{
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var token = accessToken == null ? null : accessToken.Token;
|
var token = accessToken == null ? null : accessToken.Token;
|
||||||
|
|
||||||
var task = Task.Run(() =>
|
var task = Task.Run(() =>
|
||||||
{
|
{
|
||||||
TreeNode? node;
|
TreeNode? node;
|
||||||
if (parent == null)
|
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter, accessToken);
|
||||||
{
|
|
||||||
ServerNode serverNode = new ServerNode(serverInfo, serverConnection, null, options.GroupBySchemaFlagGetter);
|
|
||||||
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
|
TreeNode rootNode = new DatabaseTreeNode(serverNode, serverInfo.DatabaseName);
|
||||||
|
|
||||||
if (nodePath == null || nodePath == string.Empty)
|
if (nodePath == null || nodePath == string.Empty)
|
||||||
{
|
{
|
||||||
nodePath = rootNode.GetNodePath();
|
nodePath = rootNode.GetNodePath();
|
||||||
@@ -103,12 +80,6 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
|||||||
return new TreeNode[0];
|
return new TreeNode[0];
|
||||||
}
|
}
|
||||||
node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
|
node = rootNode.FindNodeByPath(nodePath, true, taskCancellationTokenSource.Token);
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
node = parent;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (node != null)
|
if (node != null)
|
||||||
{
|
{
|
||||||
return node.Expand(taskCancellationTokenSource.Token, token, filters);
|
return node.Expand(taskCancellationTokenSource.Token, token, filters);
|
||||||
@@ -118,20 +89,7 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
|||||||
throw new InvalidArgumentException($"Parent node not found for path {nodePath}");
|
throw new InvalidArgumentException($"Parent node not found for path {nodePath}");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
return await RunExpandTask(task, taskCancellationTokenSource, options.OperationTimeoutSeconds);
|
||||||
|
|
||||||
if (await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(options.OperationTimeoutSeconds))) == task)
|
|
||||||
{
|
|
||||||
if (taskCancellationTokenSource.IsCancellationRequested)
|
|
||||||
{
|
|
||||||
throw new TimeoutException("The operation has timed out.");
|
|
||||||
}
|
|
||||||
return task.Result.ToArray();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
throw new TimeoutException("The operation has timed out.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
@@ -141,5 +99,59 @@ namespace Microsoft.SqlTools.SqlCore.ObjectExplorer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Expands the given node and returns the child nodes.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="node"> Node to expand </param>
|
||||||
|
/// <param name="options"> Object explorer expansion options </param>
|
||||||
|
/// <param name="filters"> Filters to be applied on the leaf nodes </param>
|
||||||
|
/// <param name="securityToken"> Security token to connect to the server. To be used in case of AAD based connections </param>
|
||||||
|
/// <returns></returns>
|
||||||
|
public static async Task<TreeNode[]> ExpandTreeNode(TreeNode node, ObjectExplorerOptions options, INodeFilter[]? filters = null, SecurityToken? securityToken = null)
|
||||||
|
{
|
||||||
|
if (node == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(node));
|
||||||
|
}
|
||||||
|
|
||||||
|
using (var taskCancellationTokenSource = new CancellationTokenSource())
|
||||||
|
{
|
||||||
|
var expandTask = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
SmoQueryContext nodeContext = node.GetContextAs<SmoQueryContext>() ?? throw new ArgumentException("Node does not have a valid context");
|
||||||
|
|
||||||
|
if(options.GroupBySchemaFlagGetter != null)
|
||||||
|
{
|
||||||
|
nodeContext.GroupBySchemaFlag = options.GroupBySchemaFlagGetter;
|
||||||
|
}
|
||||||
|
if (!nodeContext.Server.ConnectionContext.IsOpen && securityToken != null)
|
||||||
|
{
|
||||||
|
var underlyingSqlConnection = nodeContext.Server.ConnectionContext.SqlConnectionObject;
|
||||||
|
underlyingSqlConnection.AccessToken = securityToken.Token;
|
||||||
|
await underlyingSqlConnection.OpenAsync();
|
||||||
|
}
|
||||||
|
|
||||||
|
return node.Expand(taskCancellationTokenSource.Token, securityToken?.Token, filters);
|
||||||
|
});
|
||||||
|
|
||||||
|
return await RunExpandTask(expandTask, taskCancellationTokenSource, options.OperationTimeoutSeconds);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static async Task<TreeNode[]> RunExpandTask(Task<IList<TreeNode>> expansionTask, CancellationTokenSource taskCancellationTokenSource, int operationTimeoutSeconds)
|
||||||
|
{
|
||||||
|
if (await Task.WhenAny(expansionTask, Task.Delay(TimeSpan.FromSeconds(operationTimeoutSeconds))) == expansionTask)
|
||||||
|
{
|
||||||
|
if (taskCancellationTokenSource.IsCancellationRequested)
|
||||||
|
{
|
||||||
|
throw new TimeoutException("The operation has timed out.");
|
||||||
|
}
|
||||||
|
return expansionTask.Result.ToArray();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw new TimeoutException("The operation has timed out.");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -70,16 +70,16 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer
|
|||||||
var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
|
var nodes = await StatelessObjectExplorer.Expand(connectionString, null, pathWithDb, serverInfo, options);
|
||||||
Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo");
|
Assert.True(nodes.Any(node => node.Label == "dbo"), $"Expansion result for {pathWithDb} does not contain node dbo");
|
||||||
|
|
||||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes[0]);
|
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes[0], options, null, null);
|
||||||
Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1");
|
Assert.True(nodes.Any(node => node.Label == "Tables"), $"Expansion result for {pathWithDb} does not contain node t1");
|
||||||
|
|
||||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Tables"));
|
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Tables"), options, null, null);
|
||||||
Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1");
|
Assert.True(nodes.Any(node => node.Label == "dbo.t1"), $"Expansion result for {pathWithDb} does not contain node t1");
|
||||||
|
|
||||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "dbo.t1"));
|
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "dbo.t1"), options, null, null);
|
||||||
Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns");
|
Assert.True(nodes.Any(node => node.Label == "Columns"), $"Expansion result for {pathWithDb} does not contain node Columns");
|
||||||
|
|
||||||
nodes = await StatelessObjectExplorer.Expand(null, null, null, serverInfo, options, null, nodes.First(node => node.Label == "Columns"));
|
nodes = await StatelessObjectExplorer.ExpandTreeNode(nodes.First(node => node.Label == "Columns"), options, null, null);
|
||||||
Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1");
|
Assert.True(nodes.Any(node => node.Label == "c1 (int, null)"), $"Expansion result for {pathWithDb} does not contain node c1");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer
|
|||||||
context = new Mock<SmoQueryContext>(new Server(), ExtensionServiceProvider.CreateDefaultServiceProvider(), () =>
|
context = new Mock<SmoQueryContext>(new Server(), ExtensionServiceProvider.CreateDefaultServiceProvider(), () =>
|
||||||
{
|
{
|
||||||
return enableGroupBySchema;
|
return enableGroupBySchema;
|
||||||
});
|
}, null);
|
||||||
context.CallBase = true;
|
context.CallBase = true;
|
||||||
context.Object.ValidFor = ValidForFlag.None;
|
context.Object.ValidFor = ValidForFlag.None;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user