diff --git a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs index b790a5dc..2bd7d06c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/ObjectExplorer/ObjectExplorerService.cs @@ -458,6 +458,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer ObjectExplorerSession session = null; connectionDetails.PersistSecurityInfo = true; ConnectParams connectParams = new ConnectParams() { OwnerUri = uri, Connection = connectionDetails, Type = Connection.ConnectionType.ObjectExplorer }; + string connectionDatabase = connectionDetails.DatabaseName; ConnectionInfo connectionInfo; ConnectionCompleteParams connectionResult = await Connect(connectParams, uri); @@ -479,7 +480,7 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer waitForLockTimeout: timeout, bindOperation: (bindingContext, cancelToken) => { - session = ObjectExplorerSession.CreateSession(connectionResult, serviceProvider, bindingContext.ServerConnection); + session = ObjectExplorerSession.CreateSession(connectionResult, serviceProvider, bindingContext.ServerConnection, connectionDatabase); session.ConnectionInfo = connectionInfo; sessionMap.AddOrUpdate(uri, session, (key, oldSession) => session); @@ -728,11 +729,11 @@ namespace Microsoft.SqlTools.ServiceLayer.ObjectExplorer public string ErrorMessage { get; set; } - public static ObjectExplorerSession CreateSession(ConnectionCompleteParams response, IMultiServiceProvider serviceProvider, ServerConnection serverConnection) + public static ObjectExplorerSession CreateSession(ConnectionCompleteParams response, IMultiServiceProvider serviceProvider, ServerConnection serverConnection, string connectionDatabase) { ServerNode rootNode = new ServerNode(response, serviceProvider, serverConnection); var session = new ObjectExplorerSession(response.OwnerUri, rootNode, serviceProvider, serviceProvider.GetService()); - if (!DatabaseUtils.IsSystemDatabaseConnection(response.ConnectionSummary.DatabaseName)) + if (!DatabaseUtils.IsSystemDatabaseConnection(connectionDatabase)) { // Assuming the databases are in a folder under server node DatabaseTreeNode databaseNode = new DatabaseTreeNode(rootNode, response.ConnectionSummary.DatabaseName);