diff --git a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj index 2138a933..55f285d1 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj +++ b/src/Microsoft.SqlTools.ServiceLayer/Microsoft.SqlTools.ServiceLayer.csproj @@ -29,6 +29,7 @@ + diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/Contracts/FindNodesRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/Contracts/FindNodesRequest.cs new file mode 100644 index 00000000..8aed733c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/Contracts/FindNodesRequest.cs @@ -0,0 +1,55 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using Microsoft.SqlTools.Hosting.Protocol.Contracts; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Contracts +{ + /// + /// Information returned from a . + /// + public class FindNodesResponse + { + /// + /// Information describing the matching nodes in the tree + /// + public List Nodes { get; set; } + } + + /// + /// Parameters to the . + /// + public class FindNodesParams + { + /// + /// The Id returned from a . This + /// is used to disambiguate between different trees. + /// + public string SessionId { get; set; } + + public string Type { get; set; } + + public string Schema { get; set; } + + public string Name { get; set; } + + public string Database { get; set; } + + public List ParentObjectNames { get; set; } + + } + + /// + /// TODO + /// + public class FindNodesRequest + { + public static readonly + RequestType Type = + RequestType.Create("objectexplorer/findnodes"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/Nodes/TreeNode.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/Nodes/TreeNode.cs index 6ee47476..1475f80d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/Nodes/TreeNode.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/Nodes/TreeNode.cs @@ -181,7 +181,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes nodePath = path; } - public TreeNode FindNodeByPath(string path) + public TreeNode FindNodeByPath(string path, bool refreshChildren = false) { TreeNode nodeForPath = ObjectExplorerUtils.FindNode(this, node => { @@ -189,7 +189,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes }, nodeToFilter => { return path.StartsWith(nodeToFilter.GetNodePath()); - }); + }, refreshChildren); return nodeForPath; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs index 4e598229..7bd85898 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs @@ -47,7 +47,6 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer private ConnectedBindingQueue bindingQueue = new ConnectedBindingQueue(needsMetadata: false); private string connectionName = "ObjectExplorer"; - /// /// This timeout limits the amount of time that object explorer tasks can take to complete /// @@ -60,6 +59,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer { sessionMap = new ConcurrentDictionary(); applicableNodeChildFactories = new Lazy>>(() => PopulateFactories()); + NodePathGenerator.Initialize(); } internal ConnectedBindingQueue ConnectedBindingQueue @@ -136,6 +136,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer serviceHost.SetRequestHandler(ExpandRequest.Type, HandleExpandRequest); serviceHost.SetRequestHandler(RefreshRequest.Type, HandleRefreshRequest); serviceHost.SetRequestHandler(CloseSessionRequest.Type, HandleCloseSessionRequest); + serviceHost.SetRequestHandler(FindNodesRequest.Type, HandleFindNodesRequest); WorkspaceService workspaceService = WorkspaceService; if (workspaceService != null) { @@ -293,6 +294,16 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer await HandleRequestAsync(closeSession, context, "HandleCloseSessionRequest"); } + internal async Task HandleFindNodesRequest(FindNodesParams findNodesParams, RequestContext context) + { + var foundNodes = FindNodes(findNodesParams.SessionId, findNodesParams.Type, findNodesParams.Schema, findNodesParams.Name, findNodesParams.Database, findNodesParams.ParentObjectNames); + if (foundNodes == null) + { + foundNodes = new List(); + } + await context.SendResult(new FindNodesResponse { Nodes = foundNodes.Select(node => node.ToNodeInfo()).ToList() }); + } + internal void CloseSession(string uri) { ObjectExplorerSession session; @@ -689,6 +700,37 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer applicableFactories.Add(factory); } + /// + /// Find all tree nodes matching the given node information + /// + /// The ID of the object explorer session to find nodes for + /// The requested node type + /// The schema for the requested object, or null if not applicable + /// The name of the requested object + /// The name of the database containing the requested object, or null if not applicable + /// The name of any other parent objects in the object explorer tree, from highest in the tree to lowest + /// A list of nodes matching the given information, or an empty list if no nodes match + public List FindNodes(string sessionId, string typeName, string schema, string name, string databaseName, List parentNames = null) + { + var nodes = new List(); + var oeSession = sessionMap.GetValueOrDefault(sessionId); + if (oeSession == null) + { + return nodes; + } + + var outputPaths = NodePathGenerator.FindNodePaths(oeSession, typeName, schema, name, databaseName, parentNames); + foreach (var outputPath in outputPaths) + { + var treeNode = oeSession.Root.FindNodeByPath(outputPath, true); + if (treeNode != null) + { + nodes.Add(treeNode); + } + } + return nodes; + } + internal class ObjectExplorerTaskResult { public bool IsCompleted { get; set; } diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs index c2ed4c34..5c635f0f 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs @@ -41,37 +41,38 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer return VisitChildAndParents(child.Parent, visitor); } - /// - /// Finds a node by traversing the tree starting from the given node through all the children - /// - /// node to start traversing at + /// + /// Finds a node by traversing the tree starting from the given node through all the children + /// + /// node to start traversing at /// Predicate function that accesses the tree and - /// determines whether to stop going further up the tree - /// Predicate function to filter the children when traversing + /// determines whether to stop going further up the tree + /// Predicate function to filter the children when traversing /// A Tree Node that matches the condition - public static TreeNode FindNode(TreeNode node, Predicate condition, Predicate filter) - { - if(node == null) - { - return null; - } - - if (condition(node)) - { - return node; - } - foreach (var child in node.GetChildren()) - { - if (filter != null && filter(child)) - { - TreeNode childNode = FindNode(child, condition, filter); - if (childNode != null) - { - return childNode; - } - } - } - return null; - } + public static TreeNode FindNode(TreeNode node, Predicate condition, Predicate filter, bool refreshChildren = false) + { + if(node == null) + { + return null; + } + + if (condition(node)) + { + return node; + } + var children = refreshChildren && !node.IsAlwaysLeaf ? node.Refresh() : node.GetChildren(); + foreach (var child in children) + { + if (filter != null && filter(child)) + { + TreeNode childNode = FindNode(child, condition, filter, refreshChildren); + if (childNode != null) + { + return childNode; + } + } + } + return null; + } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/SmoModel/NodePathGenerator.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/SmoModel/NodePathGenerator.cs new file mode 100644 index 00000000..3af07e7c --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/SmoModel/NodePathGenerator.cs @@ -0,0 +1,268 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Xml.Serialization; + +namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel +{ + public class NodePathGenerator + { + private static ServerExplorerTree TreeRoot { get; set; } + + private static Dictionary> NodeTypeDictionary { get; set; } + + internal static void Initialize() + { + if (TreeRoot != null) + { + return; + } + + var assembly = typeof(ObjectExplorerService).Assembly; + var resource = assembly.GetManifestResourceStream("Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel.TreeNodeDefinition.xml"); + var serializer = new XmlSerializer(typeof(ServerExplorerTree)); + NodeTypeDictionary = new Dictionary>(); + using (var reader = new StreamReader(resource)) + { + TreeRoot = (ServerExplorerTree)serializer.Deserialize(reader); + } + + foreach (var node in TreeRoot.Nodes) + { + var containedType = node.ContainedType(); + if (containedType != null && node.Label() != string.Empty) + { + if (!NodeTypeDictionary.ContainsKey(containedType)) + { + NodeTypeDictionary.Add(containedType, new HashSet()); + } + NodeTypeDictionary.GetValueOrDefault(containedType).Add(node); + } + } + var serverNode = TreeRoot.Nodes.FirstOrDefault(node => node.Name == "Server"); + var serverSet = new HashSet(); + serverSet.Add(serverNode); + NodeTypeDictionary.Add("Server", serverSet); + } + + internal static HashSet FindNodePaths(ObjectExplorerService.ObjectExplorerSession objectExplorerSession, string typeName, string schema, string name, string databaseName, List parentNames = null) + { + if (TreeRoot == null) + { + Initialize(); + } + + var returnSet = new HashSet(); + var matchingNodes = NodeTypeDictionary.GetValueOrDefault(typeName); + if (matchingNodes == null) + { + return returnSet; + } + + var path = name; + if (schema != null) + { + path = schema + "." + path; + } + + if (path == null) + { + path = ""; + } + + foreach (var matchingNode in matchingNodes) + { + var paths = GenerateNodePath(objectExplorerSession, matchingNode, databaseName, parentNames, path); + foreach (var newPath in paths) + { + returnSet.Add(newPath); + } + } + return returnSet; + } + + private static HashSet GenerateNodePath(ObjectExplorerService.ObjectExplorerSession objectExplorerSession, Node currentNode, string databaseName, List parentNames, string path) + { + if (parentNames != null) + { + parentNames = parentNames.ToList(); + } + + if (currentNode.Name == "Server" || (currentNode.Name == "Database" && objectExplorerSession.Root.NodeType == "Database")) + { + var serverRoot = objectExplorerSession.Root; + if (objectExplorerSession.Root.NodeType == "Database") + { + serverRoot = objectExplorerSession.Root.Parent; + path = objectExplorerSession.Root.NodeValue + (path.Length > 0 ? ("/" + path) : ""); + } + + path = serverRoot.NodeValue + (path.Length > 0 ? ("/" + path) : ""); + var returnSet = new HashSet(); + returnSet.Add(path); + return returnSet; + } + + var currentLabel = currentNode.Label(); + if (currentLabel != string.Empty) + { + path = currentLabel + "/" + path; + var returnSet = new HashSet(); + foreach (var parent in currentNode.ParentNodes()) + { + var paths = GenerateNodePath(objectExplorerSession, parent, databaseName, parentNames, path); + foreach (var newPath in paths) + { + returnSet.Add(newPath); + } + } + return returnSet; + } + else + { + var returnSet = new HashSet(); + if (currentNode.ContainedType() == "Database") + { + path = databaseName + "/" + path; + } + else if (parentNames != null && parentNames.Count > 0) + { + var parentName = parentNames.Last(); + parentNames.RemoveAt(parentNames.Count - 1); + path = parentName + "/" + path; + } + else + { + return returnSet; + } + + foreach (var parentNode in currentNode.ParentNodes()) + { + var newPaths = GenerateNodePath(objectExplorerSession, parentNode, databaseName, parentNames, path); + foreach (var newPath in newPaths) + { + returnSet.Add(newPath); + } + } + + return returnSet; + } + } + + [XmlRoot("ServerExplorerTree")] + public class ServerExplorerTree + { + [XmlElement("Node", typeof(Node))] + public List Nodes { get; set; } + + public Node GetNode(string name) + { + foreach (var node in this.Nodes) + { + if (node.Name == name) + { + return node; + } + } + + return null; + } + } + + public class Node + { + [XmlAttribute] + public string Name { get; set; } + + [XmlAttribute] + public string LocLabel { get; set; } + + [XmlAttribute] + public string TreeNode { get; set; } + + [XmlAttribute] + public string NodeType { get; set; } + + [XmlElement("Child", typeof(Child))] + public List Children { get; set; } + + public HashSet ChildFolders() + { + var childSet = new HashSet(); + foreach (var child in this.Children) + { + var node = TreeRoot.GetNode(child.Name); + if (node != null) + { + childSet.Add(node); + } + } + return childSet; + } + + public string ContainedType() + { + if (this.TreeNode != null) + { + return this.TreeNode.Replace("TreeNode", ""); + } + else if (this.NodeType != null) + { + return this.NodeType; + } + return null; + } + + public Node ContainedObject() + { + var containedType = this.ContainedType(); + if (containedType == null) + { + return null; + } + + var containedNode = TreeRoot.GetNode(containedType); + if (containedNode == this) + { + return null; + } + + return containedNode; + } + + public string Label() + { + if (this.LocLabel.StartsWith("SR.")) + { + return SR.Keys.GetString(this.LocLabel.Remove(0, 3)); + } + + return string.Empty; + } + + public HashSet ParentNodes() + { + var parentNodes = new HashSet(); + foreach (var node in TreeRoot.Nodes) + { + if (this != node && (node.ContainedType() == this.Name || node.Children.Any(child => child.Name == this.Name))) + { + parentNodes.Add(node); + } + } + return parentNodes; + } + } + + public class Child + { + [XmlAttribute] + public string Name { get; set; } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/NodePathGeneratorTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/NodePathGeneratorTests.cs new file mode 100644 index 00000000..fd9487aa --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/NodePathGeneratorTests.cs @@ -0,0 +1,164 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.SqlClient; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.ObjectExplorer; +using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Contracts; +using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes; +using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel; +using Microsoft.SqlTools.ServiceLayer.UnitTests.Utility; +using Moq; +using Xunit; +using Microsoft.SqlTools.ServiceLayer.LanguageServices; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlTools.ServiceLayer.Test.Common.RequestContextMocking; + +namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer +{ + + public class NodePathGeneratorTests + { + private ObjectExplorerService.ObjectExplorerSession serverSession; + private ObjectExplorerService.ObjectExplorerSession databaseSession; + private const string serverName = "testServer"; + private const string databaseName = "testDatabase"; + + public NodePathGeneratorTests() + { + var serverRoot = new TreeNode + { + NodeType = "Server", + NodeValue = serverName + }; + + serverSession = new ObjectExplorerService.ObjectExplorerSession("serverUri", serverRoot, null, null); + + var databaseRoot = new TreeNode + { + NodeType = "Database", + NodeValue = databaseName, + Parent = serverRoot + }; + + databaseSession = new ObjectExplorerService.ObjectExplorerSession("databaseUri", databaseRoot, null, null); + } + + [Fact] + public void FindCorrectPathsForTableWithServerRoot() + { + var paths = NodePathGenerator.FindNodePaths(serverSession, "Table", "testSchema", "testTable", databaseName); + var expectedPaths = new List + { + "testServer/Databases/testDatabase/Tables/testSchema.testTable", + "testServer/Databases/System Databases/testDatabase/Tables/testSchema.testTable", + "testServer/Databases/testDatabase/Tables/System Tables/testSchema.testTable", + "testServer/Databases/System Databases/testDatabase/Tables/System Tables/testSchema.testTable" + }; + + Assert.Equal(expectedPaths.Count, paths.Count); + foreach (var expectedPath in expectedPaths) + { + Assert.True(paths.Contains(expectedPath)); + } + } + + [Fact] + public void FindCorrectPathsForTableWithDatabaseRoot() + { + var paths = NodePathGenerator.FindNodePaths(databaseSession, "Table", "testSchema", "testTable", null); + var expectedPaths = new List + { + "testServer/testDatabase/Tables/testSchema.testTable", + "testServer/testDatabase/Tables/System Tables/testSchema.testTable" + }; + + Assert.Equal(expectedPaths.Count, paths.Count); + foreach (var expectedPath in expectedPaths) + { + Assert.True(paths.Contains(expectedPath)); + } + } + + [Fact] + public void FindCorrectPathsForColumnWithServerRoot() + { + var paths = NodePathGenerator.FindNodePaths(serverSession, "Column", null, "testColumn", databaseName, new List { "testSchema.testTable" }); + var expectedPaths = new List + { + "testServer/Databases/testDatabase/Tables/testSchema.testTable/Columns/testColumn", + "testServer/Databases/System Databases/testDatabase/Tables/testSchema.testTable/Columns/testColumn", + "testServer/Databases/testDatabase/Tables/System Tables/testSchema.testTable/Columns/testColumn", + "testServer/Databases/System Databases/testDatabase/Tables/System Tables/testSchema.testTable/Columns/testColumn", + "testServer/Databases/testDatabase/Views/testSchema.testTable/Columns/testColumn", + "testServer/Databases/System Databases/testDatabase/Views/testSchema.testTable/Columns/testColumn", + "testServer/Databases/testDatabase/Views/System Views/testSchema.testTable/Columns/testColumn", + "testServer/Databases/System Databases/testDatabase/Views/System Views/testSchema.testTable/Columns/testColumn" + }; + + Assert.Equal(expectedPaths.Count, paths.Count); + foreach (var expectedPath in expectedPaths) + { + Assert.True(paths.Contains(expectedPath)); + } + } + + [Fact] + public void FindCorrectPathsForColumnWithDatabaseRoot() + { + var paths = NodePathGenerator.FindNodePaths(databaseSession, "Column", null, "testColumn", databaseName, new List { "testSchema.testTable" }); + var expectedPaths = new List + { + "testServer/testDatabase/Tables/testSchema.testTable/Columns/testColumn", + "testServer/testDatabase/Tables/System Tables/testSchema.testTable/Columns/testColumn", + "testServer/testDatabase/Views/testSchema.testTable/Columns/testColumn", + "testServer/testDatabase/Views/System Views/testSchema.testTable/Columns/testColumn" + }; + + Assert.Equal(expectedPaths.Count, paths.Count); + foreach (var expectedPath in expectedPaths) + { + Assert.True(paths.Contains(expectedPath)); + } + } + + [Fact] + public void FindCorrectPathsForDatabase() + { + var paths = NodePathGenerator.FindNodePaths(serverSession, "Database", null, databaseName, null); + var expectedPaths = new List + { + "testServer/Databases/testDatabase", + "testServer/Databases/System Databases/testDatabase" + }; + + Assert.Equal(expectedPaths.Count, paths.Count); + foreach (var expectedPath in expectedPaths) + { + Assert.True(paths.Contains(expectedPath)); + } + } + + [Fact] + public void FindPathForInvalidTypeReturnsEmpty() + { + var serverPaths = NodePathGenerator.FindNodePaths(serverSession, "WrongType", "testSchema", "testTable", databaseName); + Assert.Equal(0, serverPaths.Count); + } + + [Fact] + public void FindPathMissingParentReturnsEmpty() + { + var serverPaths = NodePathGenerator.FindNodePaths(serverSession, "Column", "testSchema", "testColumn", databaseName); + Assert.Equal(0, serverPaths.Count); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs index 722e74c1..d538c7c4 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs @@ -243,6 +243,26 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer connectionServiceMock.Verify(c => c.Disconnect(It.IsAny())); } + [Fact] + public async Task FindNodesReturnsMatchingNode() + { + var session = await CreateSession(); + + var foundNodes = service.FindNodes(session.SessionId, "Server", null, null, null); + Assert.Equal(1, foundNodes.Count); + Assert.Equal("Server", foundNodes[0].NodeType); + Assert.Equal(session.RootNode.NodePath, foundNodes[0].ToNodeInfo().NodePath); + } + + [Fact] + public async Task FindNodesReturnsEmptyListForNoMatch() + { + var session = await CreateSession(); + + var foundNodes = service.FindNodes(session.SessionId, "Table", "testSchema", "testTable", "testDatabase"); + Assert.Equal(0, foundNodes.Count); + } + private async Task CreateSession() { SessionCreatedParameters sessionResult = null;