From 809254d2e2e62de6e0176235c33215c7f438cf8b Mon Sep 17 00:00:00 2001 From: Abbie Petchtes Date: Thu, 20 Apr 2017 12:33:49 -0700 Subject: [PATCH] Fix the OE service where returns database as the root node for database connection (#322) * Fix the OE service where retuns database as the root node if the connection contains database name * Fix OE tests * addressed the comments * addresses comment * fix OE test and add more tests * fix VerifyAdventureWorksDatabaseObjects test --- .../ObjectExplorer/ObjectExplorerService.cs | 16 ++++- .../ObjectExplorer/ObjectExplorerUtils.cs | 15 ++++ .../ObjectExplorer/SmoModel/ServerNode.cs | 55 ++++++++------- .../ObjectExplorerServiceTests.cs | 68 ++++++++++++++++--- .../ObjectExplorer/NodeTests.cs | 2 +- .../ObjectExplorerServiceTests.cs | 43 +++++++++++- 6 files changed, 159 insertions(+), 40 deletions(-) diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs index 66b97e34..99b0b385 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs @@ -346,8 +346,20 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer public static ObjectExplorerSession CreateSession(ConnectionCompleteParams response, IMultiServiceProvider serviceProvider) { - TreeNode serverNode = new ServerNode(response, serviceProvider); - return new ObjectExplorerSession(response.OwnerUri, serverNode, serviceProvider, serviceProvider.GetService()); + TreeNode rootNode = new ServerNode(response, serviceProvider); + var session = new ObjectExplorerSession(response.OwnerUri, rootNode, serviceProvider, serviceProvider.GetService()); + if (!ObjectExplorerUtils.IsSystemDatabaseConnection(response.ConnectionSummary.DatabaseName)) + { + // Assuming the databases are in a folder under server node + var children = rootNode.Expand(); + var databasesRoot = children.FirstOrDefault(x => x.NodeTypeId == NodeTypes.Databases); + var databasesChildren = databasesRoot.Expand(); + var databases = databasesChildren.Where(x => x.NodeType == NodeTypes.Database.ToString()); + var databaseNode = databases.FirstOrDefault(d => d.Label == response.ConnectionSummary.DatabaseName); + databaseNode.Label = rootNode.Label; + session.Root = databaseNode; + } + return session; } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs index 443c93a3..f646ffa4 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerUtils.cs @@ -5,6 +5,7 @@ using System; using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes; +using Microsoft.SqlTools.ServiceLayer.Utility; namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer { @@ -71,6 +72,20 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer } } return null; + } + + /// + /// Check if the database is a system database + /// + /// the name of database + /// return true if the database is a system database + public static bool IsSystemDatabaseConnection(string databaseName) + { + return (string.IsNullOrWhiteSpace(databaseName) || + string.Compare(databaseName, CommonConstants.MasterDatabaseName, StringComparison.OrdinalIgnoreCase) == 0 || + string.Compare(databaseName, CommonConstants.MsdbDatabaseName, StringComparison.OrdinalIgnoreCase) == 0 || + string.Compare(databaseName, CommonConstants.ModelDatabaseName, StringComparison.OrdinalIgnoreCase) == 0 || + string.Compare(databaseName, CommonConstants.TempDbDatabaseName, StringComparison.OrdinalIgnoreCase) == 0); } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/SmoModel/ServerNode.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/SmoModel/ServerNode.cs index 9ddbc948..0cf839ad 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/SmoModel/ServerNode.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/SmoModel/ServerNode.cs @@ -4,20 +4,20 @@ // using System; -using System.Data.Common; +using System.Data.Common; using System.Data.SqlClient; using System.Globalization; -using System.Linq; +using System.Linq; using Microsoft.SqlServer.Management.Common; using Microsoft.SqlServer.Management.Smo; -using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.Extensibility; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; -using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; using Microsoft.SqlTools.ServiceLayer.ObjectExplorer.Nodes; using Microsoft.SqlTools.ServiceLayer.Utility; -using Microsoft.SqlTools.Utility; - +using Microsoft.SqlTools.Utility; + namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel { /// @@ -86,12 +86,17 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel // TODO Consider adding IsAuthenticatingDatabaseMaster check in the code and // referencing result here - if (!string.IsNullOrWhiteSpace(connectionSummary.DatabaseName) && - string.Compare(connectionSummary.DatabaseName, CommonConstants.MasterDatabaseName, StringComparison.OrdinalIgnoreCase) != 0 && - (serverInfo.IsCloud /* || !ci.IsAuthenticatingDatabaseMaster */)) + if (!ObjectExplorerUtils.IsSystemDatabaseConnection(connectionSummary.DatabaseName)) { // We either have an azure with a database specified or a Denali database using a contained user - userName += ", " + connectionSummary.DatabaseName; + if (string.IsNullOrWhiteSpace(userName)) + { + userName = connectionSummary.DatabaseName; + } + else + { + userName += ", " + connectionSummary.DatabaseName; + } } string label; @@ -124,7 +129,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel ConnectionInfo connectionInfo; SqlConnection connection = null; // Get server object from connection - if (!connectionService.TryFindConnection(this.connectionUri, out connectionInfo) || + if (!connectionService.TryFindConnection(this.connectionUri, out connectionInfo) || connectionInfo.AllConnections == null || connectionInfo.AllConnections.Count == 0) { ErrorStateMessage = string.Format(CultureInfo.CurrentCulture, @@ -135,19 +140,19 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel DbConnection dbConnection = connectionInfo.AllConnections.First(); ReliableSqlConnection reliableSqlConnection = dbConnection as ReliableSqlConnection; SqlConnection sqlConnection = dbConnection as SqlConnection; - if (reliableSqlConnection != null) - { - connection = reliableSqlConnection.GetUnderlyingConnection(); - } - else if (sqlConnection != null) - { - connection = sqlConnection; - } - else - { - ErrorStateMessage = string.Format(CultureInfo.CurrentCulture, + if (reliableSqlConnection != null) + { + connection = reliableSqlConnection.GetUnderlyingConnection(); + } + else if (sqlConnection != null) + { + connection = sqlConnection; + } + else + { + ErrorStateMessage = string.Format(CultureInfo.CurrentCulture, SR.ServerNodeConnectionError, connectionSummary.ServerName); - return null; + return null; } try @@ -170,8 +175,8 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer.SmoModel Logger.Write(LogLevel.Error, "Exception at ServerNode.CreateContext() : " + exceptionMessage); this.ErrorStateMessage = string.Format(SR.TreeNodeError, exceptionMessage); return null; - } - + } + public override object GetContext() { return context.Value; diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/ObjectExplorerServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/ObjectExplorerServiceTests.cs index 2d9f0e45..9f0e6795 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/ObjectExplorerServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ObjectExplorer/ObjectExplorerServiceTests.cs @@ -23,15 +23,43 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer private ObjectExplorerService _service = TestServiceProvider.Instance.ObjectExplorerService; [Fact] - public async void CreateSessionAndExpandOnTheServerShouldReturnTheDatabases() + public async void CreateSessionAndExpandOnTheServerShouldReturnServerAsTheRoot() { var query = ""; - string uri = "CreateSessionAndExpand"; + string uri = "CreateSessionAndExpandServer"; + string databaseName = null; + using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) + { + var session = await CreateSession(null, uri); + await ExpandServerNodeAndVerifyDatabaseHierachy(testDb.DatabaseName, session); + CancelConnection(uri); + } + } + + [Fact] + public async void CreateSessionWithTempdbAndExpandOnTheServerShouldReturnServerAsTheRoot() + { + var query = ""; + string uri = "CreateSessionAndExpandServer"; + string databaseName = null; + using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) + { + var session = await CreateSession("tempdb", uri); + await ExpandServerNodeAndVerifyDatabaseHierachy(testDb.DatabaseName, session); + CancelConnection(uri); + } + } + + [Fact] + public async void CreateSessionAndExpandOnTheDatabaseShouldReturnDatabaseAsTheRoot() + { + var query = ""; + string uri = "CreateSessionAndExpandDatabase"; string databaseName = null; using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) { var session = await CreateSession(testDb.DatabaseName, uri); - await CreateSessionAndDatabaseNode(testDb.DatabaseName, session); + ExpandAndVerifyDatabaseNode(testDb.DatabaseName, session); CancelConnection(uri); } } @@ -44,7 +72,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer return await _service.DoCreateSession(details, uri); } - private async Task CreateSessionAndDatabaseNode(string databaseName, ObjectExplorerSession session) + private async Task ExpandServerNodeAndVerifyDatabaseHierachy(string databaseName, ObjectExplorerSession session) { Assert.NotNull(session); Assert.NotNull(session.Root); @@ -70,6 +98,26 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer return databaseNode; } + private void ExpandAndVerifyDatabaseNode(string databaseName, ObjectExplorerSession session) + { + Assert.NotNull(session); + Assert.NotNull(session.Root); + NodeInfo nodeInfo = session.Root.ToNodeInfo(); + Assert.Equal(nodeInfo.IsLeaf, false); + Assert.Equal(nodeInfo.NodeType, NodeTypes.Database.ToString()); + Assert.True(nodeInfo.Label.Contains(databaseName)); + var children = session.Root.Expand(); + + //All server children should be folder nodes + foreach (var item in children) + { + Assert.Equal(item.NodeType, "Folder"); + } + + var tablesRoot = children.FirstOrDefault(x => x.NodeTypeId == NodeTypes.Tables); + Assert.NotNull(tablesRoot); + } + private void CancelConnection(string uri) { //ConnectionService.Instance.CancelConnect(new CancelConnectParams @@ -129,8 +177,8 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) { - var session = await CreateSession(testDb.DatabaseName, uri); - var databaseNodeInfo = await CreateSessionAndDatabaseNode(testDb.DatabaseName, session); + var session = await CreateSession(null, uri); + var databaseNodeInfo = await ExpandServerNodeAndVerifyDatabaseHierachy(testDb.DatabaseName, session); await ExpandTree(databaseNodeInfo, session); CancelConnection(uri); } @@ -146,7 +194,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) { var session = await CreateSession(testDb.DatabaseName, uri); - var databaseNodeInfo = await CreateSessionAndDatabaseNode(testDb.DatabaseName, session); + var databaseNodeInfo = await ExpandServerNodeAndVerifyDatabaseHierachy(testDb.DatabaseName, session); await ExpandTree(databaseNodeInfo, session); CancelConnection(uri); } @@ -162,7 +210,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) { var session = await CreateSession(testDb.DatabaseName, uri); - var databaseNodeInfo = await CreateSessionAndDatabaseNode(testDb.DatabaseName, session); + var databaseNodeInfo = await ExpandServerNodeAndVerifyDatabaseHierachy(testDb.DatabaseName, session); await ExpandTree(databaseNodeInfo, session); CancelConnection(uri); } @@ -178,7 +226,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) { var session = await CreateSession(testDb.DatabaseName, uri); - var databaseNodeInfo = await CreateSessionAndDatabaseNode(testDb.DatabaseName, session); + var databaseNodeInfo = await ExpandServerNodeAndVerifyDatabaseHierachy(testDb.DatabaseName, session); await ExpandTree(databaseNodeInfo, session); CancelConnection(uri); } @@ -194,7 +242,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ObjectExplorer using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName, query, uri)) { var session = await CreateSession(testDb.DatabaseName, uri); - var databaseNodeInfo = await CreateSessionAndDatabaseNode(testDb.DatabaseName, session); + var databaseNodeInfo = await ExpandServerNodeAndVerifyDatabaseHierachy(testDb.DatabaseName, session); await ExpandTree(databaseNodeInfo, session); CancelConnection(uri); } diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/NodeTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/NodeTests.cs index da1870b1..ce7fa73d 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/NodeTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/NodeTests.cs @@ -347,7 +347,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer TreeNode databases = children[0]; IList dbChildren = databases.Expand(); Assert.Equal(2, dbChildren.Count); - Assert.Equal("System Databases", dbChildren[0].NodeValue); + Assert.Equal(SR.SchemaHierarchy_SystemDatabases, dbChildren[0].NodeValue); TreeNode dbNode = dbChildren[1]; Assert.Equal(dbName, dbNode.NodeValue); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs index 1c6adb77..74d32163 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ObjectExplorer/ObjectExplorerServiceTests.cs @@ -75,10 +75,49 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.ObjectExplorer [Fact] - public async Task CreateSessionRequestReturnsSuccessAndNodeInfo() + public async Task CreateSessionRequestWithMasterConnectionReturnsServerSuccessAndNodeInfo() { // Given the connection service fails to connect - ConnectionDetails details = TestObjects.GetTestConnectionDetails(); + ConnectionDetails details = new ConnectionDetails() + { + UserName = "user", + Password = "password", + DatabaseName = "master", + ServerName = "serverName" + }; + await CreateSessionRequestAndVerifyServerNodeHelper(details); + } + + [Fact] + public async Task CreateSessionRequestWithEmptyConnectionReturnsServerSuccessAndNodeInfo() + { + // Given the connection service fails to connect + ConnectionDetails details = new ConnectionDetails() + { + UserName = "user", + Password = "password", + DatabaseName = "", + ServerName = "serverName" + }; + await CreateSessionRequestAndVerifyServerNodeHelper(details); + } + + [Fact] + public async Task CreateSessionRequestWithMsdbConnectionReturnsServerSuccessAndNodeInfo() + { + // Given the connection service fails to connect + ConnectionDetails details = new ConnectionDetails() + { + UserName = "user", + Password = "password", + DatabaseName = "msdb", + ServerName = "serverName" + }; + await CreateSessionRequestAndVerifyServerNodeHelper(details); + } + + private async Task CreateSessionRequestAndVerifyServerNodeHelper(ConnectionDetails details) + { serviceHostMock.AddEventHandling(ConnectionCompleteNotification.Type, null); connectionServiceMock.Setup(c => c.Connect(It.IsAny()))