diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs index 05eeba8f..f2e9a0b0 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs @@ -8,10 +8,13 @@ using System; using System.Collections.Generic; using System.Data; using System.Data.Common; +using System.Diagnostics; using System.Linq; using Microsoft.SqlServer.Management.Common; using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; +using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.Connection { @@ -64,14 +67,50 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection public ListDatabasesResponse HandleRequest(ISqlConnectionFactory connectionFactory, ConnectionInfo connectionInfo) { + ListDatabasesResponse response = new ListDatabasesResponse(); ConnectionDetails connectionDetails = connectionInfo.ConnectionDetails.Clone(); + // Running query against sys.databases view will only return a subset of databases the current login/user might have access to, we need to + // query the master database to get the full database list, but for users without master db access, we have to query the + // original database as a fallback. + var databasesToTry = new List() { CommonConstants.MasterDatabaseName }; + if (connectionDetails.DatabaseName != CommonConstants.MasterDatabaseName) + { + databasesToTry.Add(connectionDetails.DatabaseName); + } + for (int i = 0; i < databasesToTry.Count; i++) + { + try + { + connectionDetails.DatabaseName = databasesToTry[i]; + var results = this.GetResults(connectionFactory, connectionDetails); + SetResponse(response, results); + break; + } + catch (Microsoft.Data.SqlClient.SqlException ex) + { + // Retry when login attempt failed. + // https://learn.microsoft.com/sql/relational-databases/errors-events/mssqlserver-18456-database-engine-error + if (i != databasesToTry.Count - 1 && ex.Number == 18456) + { + Logger.Write(TraceEventType.Information, string.Format("Failed to get database list from database '{0}', will fallback to original database.", databasesToTry[i])); + continue; + } + else + { + throw; + } + } - // Connect to master - connectionDetails.DatabaseName = "master"; + } + return response; + } + + private T[] GetResults(ISqlConnectionFactory connectionFactory, ConnectionDetails connectionDetails) + { + List results = new List(); using (var connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken)) { connection.Open(); - ListDatabasesResponse response = new ListDatabasesResponse(); using (DbCommand command = connection.CreateCommand()) { command.CommandText = this.QueryText; @@ -79,7 +118,6 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection command.CommandType = CommandType.Text; using (var reader = command.ExecuteReader()) { - List results = new List(); while (reader.Read()) { results.Add(this.CreateItem(reader)); @@ -87,16 +125,17 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection // Put system databases at the top of the list results = results.Where(s => SystemDatabases.Any(x => this.NameMatches(x, s))).Concat( results.Where(s => SystemDatabases.All(x => !this.NameMatches(x, s)))).ToList(); - SetResponse(response, results.ToArray()); } } connection.Close(); - return response; } + return results.ToArray(); } protected abstract bool NameMatches(string databaseName, T item); + protected abstract T CreateItem(DbDataReader reader); + protected abstract void SetResponse(ListDatabasesResponse response, T[] results); }