From 8f6662b0196fa348cbc458f5c87ddaf09d2f5b0a Mon Sep 17 00:00:00 2001 From: Alan Ren Date: Tue, 28 Apr 2020 16:12:56 -0700 Subject: [PATCH] extend the list databases request (#951) * extend the list databases request * reuse databaseInfo and refactor * fix typo --- .../Connection/ConnectionService.cs | 37 +-- .../Contracts/ListDatabasesParams.cs | 5 + .../Contracts/ListDatabasesResponse.cs | 7 + .../Connection/ListDatabaseRequestHandler.cs | 226 ++++++++++++++++++ .../Connection/ConnectionServiceTests.cs | 146 ++++++++--- .../Utility/TestDbDataReader.cs | 19 +- 6 files changed, 369 insertions(+), 71 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs index ce6e0c45..099b68cb 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ConnectionService.cs @@ -910,41 +910,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection { throw new Exception(SR.ConnectionServiceListDbErrorNotConnected(owner)); } - ConnectionDetails connectionDetails = info.ConnectionDetails.Clone(); - - // Connect to master and query sys.databases - connectionDetails.DatabaseName = "master"; - var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken); - connection.Open(); - - List results = new List(); - var systemDatabases = new[] { "master", "model", "msdb", "tempdb" }; - using (DbCommand command = connection.CreateCommand()) - { - command.CommandText = @"SELECT name FROM sys.databases WHERE state_desc='ONLINE' ORDER BY name ASC"; - command.CommandTimeout = 15; - command.CommandType = CommandType.Text; - - using (var reader = command.ExecuteReader()) - { - while (reader.Read()) - { - results.Add(reader[0].ToString()); - } - } - } - - // Put system databases at the top of the list - results = - results.Where(s => systemDatabases.Any(s.Equals)).Concat( - results.Where(s => systemDatabases.All(x => !s.Equals(x)))).ToList(); - - connection.Close(); - - ListDatabasesResponse response = new ListDatabasesResponse(); - response.DatabaseNames = results.ToArray(); - - return response; + var handler = ListDatabaseRequestHandlerFactory.getHandler(listDatabasesParams.IncludeDetails.HasTrue(), info.IsSqlDb); + return handler.HandleRequest(this.connectionFactory, info); } public void InitializeService(IProtocolEndpoint serviceHost) diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesParams.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesParams.cs index fa607e75..cef314c5 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesParams.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesParams.cs @@ -14,5 +14,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// URI of the owner of the connection requesting the list of databases. /// public string OwnerUri { get; set; } + + /// + /// whether to include the details of the databases. + /// + public bool? IncludeDetails { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesResponse.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesResponse.cs index 68610803..2880481d 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesResponse.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/Contracts/ListDatabasesResponse.cs @@ -3,6 +3,8 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; + namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts { /// @@ -14,5 +16,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts /// Gets or sets the list of database names. /// public string[] DatabaseNames { get; set; } + + /// + /// Gets or sets the databases details. + /// + public DatabaseInfo[] Databases { get; set; } } } diff --git a/src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs b/src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs new file mode 100644 index 00000000..006bcdf4 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Connection/ListDatabaseRequestHandler.cs @@ -0,0 +1,226 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Connection +{ + public class ListDatabasesRequestDatabaseProperties + { + public const string Name = "name"; + public const string SizeInMB = "sizeInMB"; + public const string State = "state"; + public const string LastBackup = "lastBackup"; + } + + /// + /// Factory class for ListDatabasesRequest handler + /// + static class ListDatabaseRequestHandlerFactory + { + public static IListDatabaseRequestHandler getHandler(bool includeDetails, bool isSqlDB) + { + if (!includeDetails) + { + return new DatabaseNamesHandler(); + } + else if (isSqlDB) + { + return new SqlDBDatabaseDetailHandler(); + } + else + { + return new SqlServerDatabaseDetailHandler(); + } + } + } + + /// + /// Interface of ListDatabasesRequest handler + /// + interface IListDatabaseRequestHandler + { + ListDatabasesResponse HandleRequest(ISqlConnectionFactory connectionFactory, ConnectionInfo connectionInfo); + } + + /// + /// Base handler + /// + abstract class ListDatabaseRequestHandler : IListDatabaseRequestHandler + { + private static readonly string[] SystemDatabases = new string[] { "master", "model", "msdb", "tempdb" }; + + public abstract string QueryText { get; } + + public ListDatabasesResponse HandleRequest(ISqlConnectionFactory connectionFactory, ConnectionInfo connectionInfo) + { + ConnectionDetails connectionDetails = connectionInfo.ConnectionDetails.Clone(); + + // Connect to master + connectionDetails.DatabaseName = "master"; + using (var connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken)) + { + connection.Open(); + ListDatabasesResponse response = new ListDatabasesResponse(); + using (DbCommand command = connection.CreateCommand()) + { + command.CommandText = this.QueryText; + command.CommandTimeout = 15; + command.CommandType = CommandType.Text; + using (var reader = command.ExecuteReader()) + { + List results = new List(); + while (reader.Read()) + { + results.Add(this.CreateItem(reader)); + } + // Put system databases at the top of the list + results = results.Where(s => SystemDatabases.Any(x => this.NameMatches(x, s))).Concat( + results.Where(s => SystemDatabases.All(x => !this.NameMatches(x, s)))).ToList(); + SetResponse(response, results.ToArray()); + } + } + connection.Close(); + return response; + } + } + + protected abstract bool NameMatches(string databaseName, T item); + protected abstract T CreateItem(DbDataReader reader); + protected abstract void SetResponse(ListDatabasesResponse response, T[] results); + } + + /// + /// database names handler + /// + class DatabaseNamesHandler : ListDatabaseRequestHandler + { + public override string QueryText + { + get + { + return @"SELECT name FROM sys.databases WHERE state_desc='ONLINE' ORDER BY name ASC"; + } + } + + protected override string CreateItem(DbDataReader reader) + { + return reader[0].ToString(); + } + + protected override bool NameMatches(string databaseName, string item) + { + return databaseName == item; + } + + protected override void SetResponse(ListDatabasesResponse response, string[] results) + { + response.DatabaseNames = results; + } + } + + + abstract class BaseDatabaseDetailHandler : ListDatabaseRequestHandler + { + protected override bool NameMatches(string databaseName, DatabaseInfo item) + { + return databaseName == item.Options[ListDatabasesRequestDatabaseProperties.Name].ToString(); + } + + protected override void SetResponse(ListDatabasesResponse response, DatabaseInfo[] results) + { + response.Databases = results; + } + + protected override DatabaseInfo CreateItem(DbDataReader reader) + { + DatabaseInfo databaseInfo = new DatabaseInfo(); + SetProperties(reader, databaseInfo); + return databaseInfo; + } + + protected virtual void SetProperties(DbDataReader reader, DatabaseInfo databaseInfo) + { + databaseInfo.Options[ListDatabasesRequestDatabaseProperties.Name] = reader["name"].ToString(); + databaseInfo.Options[ListDatabasesRequestDatabaseProperties.State] = reader["state"].ToString(); + databaseInfo.Options[ListDatabasesRequestDatabaseProperties.SizeInMB] = reader["size"].ToString(); + } + } + + /// + /// Standalone SQL Server database detail handler + /// + class SqlServerDatabaseDetailHandler : BaseDatabaseDetailHandler + { + public override string QueryText + { + get + { + return @" +WITH + db_size + AS + ( + SELECT database_id, CAST(SUM(size) * 8.0 / 1024 AS INTEGER) size + FROM sys.master_files + GROUP BY database_id + ), + db_backup + AS + ( + SELECT database_name, MAX(backup_start_date) AS last_backup + FROM msdb..backupset + GROUP BY database_name + ) +SELECT name, state_desc AS state, db_size.size, db_backup.last_backup +FROM sys.databases LEFT JOIN db_size ON sys.databases.database_id = db_size.database_id +LEFT JOIN db_backup ON sys.databases.name = db_backup.database_name +WHERE state_desc='ONLINE' +ORDER BY name ASC"; + } + } + + protected override void SetProperties(DbDataReader reader, DatabaseInfo databaseInfo) + { + base.SetProperties(reader, databaseInfo); + databaseInfo.Options[ListDatabasesRequestDatabaseProperties.LastBackup] = reader["last_backup"] == DBNull.Value ? "" : Convert.ToDateTime(reader["last_backup"]).ToString("yyyy-MM-dd hh:mm:ss"); + } + } + + /// + /// SQL DB database detail handler + /// + class SqlDBDatabaseDetailHandler : BaseDatabaseDetailHandler + { + public override string QueryText + { + get + { + return @" +WITH + db_size + AS + ( + SELECT name, storage_in_megabytes AS size + FROM ( +SELECT database_name name, max(end_time) size_time + FROM sys.resource_stats + GROUP BY database_name) db_size_time + LEFT JOIN sys.resource_stats ON database_name = name AND size_time = end_time + ) +SELECT db.name, state_desc AS state, size +FROM sys.databases db LEFT JOIN db_size ON db.name = db_size.name +WHERE state_desc='ONLINE' +ORDER BY name ASC +"; + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs index d099782c..dac4c9dd 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Connection/ConnectionServiceTests.cs @@ -20,6 +20,7 @@ using Moq; using Moq.Protected; using Xunit; using System.Linq; +using Microsoft.SqlTools.ServiceLayer.Admin.Contracts; namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection { @@ -144,7 +145,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection // Given a second connection that succeeds var mockConnection2 = new Mock { CallBase = true }; mockConnection2.Setup(x => x.OpenAsync(It.IsAny())) - .Returns(() => Task.Run(() => {})); + .Returns(() => Task.Run(() => { })); var mockFactory = new Mock(); mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) @@ -338,7 +339,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection // register disconnect callback connectionService.RegisterOnDisconnectTask( - (result, uri) => { + (result, uri) => + { callbackInvoked = true; Assert.True(uri.Equals(ownerUri)); return Task.FromResult(true); @@ -431,7 +433,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection .Connect(new ConnectParams() { OwnerUri = ownerUri, - Connection = new ConnectionDetails() { + Connection = new ConnectionDetails() + { ServerName = server, DatabaseName = database, UserName = userName, @@ -465,7 +468,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection .Connect(new ConnectParams() { OwnerUri = "file:///my/test/file.sql", - Connection = new ConnectionDetails() { + Connection = new ConnectionDetails() + { ServerName = "my-server", DatabaseName = "test", UserName = userName, @@ -761,7 +765,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection // register disconnect callback connectionService.RegisterOnDisconnectTask( - (result, uri) => { + (result, uri) => + { callbackInvoked = true; Assert.True(uri.Equals(ownerUri)); return Task.FromResult(true); @@ -853,28 +858,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.False(disconnectResult); } - /// - /// Verifies the the list databases operation lists database names for the server used by a connection. - /// - [Fact] - public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames() + private async Task RunListDatabasesRequestHandler(TestResultSet testdata, bool? includeDetails) { - // Result set for the query of database names - TestDbColumn[] cols = {new TestDbColumn("name")}; - object[][] rows = - { - new object[] {"master"}, - new object[] {"model"}, - new object[] {"msdb"}, - new object[] {"tempdb"}, - new object[] {"mydatabase"} - }; - TestResultSet data = new TestResultSet(cols, rows); - // Setup mock connection factory to inject query results var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) - .Returns(CreateMockDbConnection(new[] {data})); + .Returns(CreateMockDbConnection(new[] { testdata })); var connectionService = new ConnectionService(mockFactory.Object); // connect to a database instance @@ -891,9 +880,34 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.NotEmpty(connectionResult.ConnectionId); // list databases for the connection - ListDatabasesParams parameters = new ListDatabasesParams {OwnerUri = ownerUri}; - var listDatabasesResult = connectionService.ListDatabases(parameters); - string[] databaseNames = listDatabasesResult.DatabaseNames; + ListDatabasesParams parameters = new ListDatabasesParams + { + OwnerUri = ownerUri, + IncludeDetails = includeDetails + }; + return connectionService.ListDatabases(parameters); + } + + /// + /// Verifies the the list databases operation lists database names for the server used by a connection. + /// + [Fact] + public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames() + { + // Result set for the query of database names + TestDbColumn[] cols = { new TestDbColumn("name") }; + object[][] rows = + { + new object[] {"mydatabase"}, // this should be sorted to the end in the response + new object[] {"master"}, + new object[] {"model"}, + new object[] {"msdb"}, + new object[] {"tempdb"} + }; + TestResultSet data = new TestResultSet(cols, rows); + var response = await RunListDatabasesRequestHandler(testdata: data, includeDetails: null); + + string[] databaseNames = response.DatabaseNames; Assert.Equal(databaseNames.Length, 5); Assert.Equal(databaseNames[0], "master"); @@ -903,6 +917,79 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection Assert.Equal(databaseNames[4], "mydatabase"); } + /// + /// Verifies the the list databases operation lists database names for the server used by a connection. + /// + [Fact] + public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseDetails() + { + // Result set for the query of database names + TestDbColumn[] cols = { + new TestDbColumn("name"), + new TestDbColumn("state"), + new TestDbColumn("size"), + new TestDbColumn("last_backup") + }; + object[][] rows = + { + new object[] {"mydatabase", "Online", "10", "2010-01-01 11:11:11"}, // this should be sorted to the end in the response + new object[] {"master", "Online", "11", "2010-01-01 11:11:12"}, + new object[] {"model", "Offline", "12", "2010-01-01 11:11:13"}, + new object[] {"msdb", "Online", "13", "2010-01-01 11:11:14"}, + new object[] {"tempdb", "Online", "14", "2010-01-01 11:11:15"} + }; + TestResultSet data = new TestResultSet(cols, rows); + var response = await RunListDatabasesRequestHandler(testdata: data, includeDetails: true); + + Assert.Equal(response.Databases.Length, 5); + VerifyDatabaseDetail(rows[0], response.Databases[4]); + VerifyDatabaseDetail(rows[1], response.Databases[0]); + VerifyDatabaseDetail(rows[2], response.Databases[1]); + VerifyDatabaseDetail(rows[3], response.Databases[2]); + VerifyDatabaseDetail(rows[4], response.Databases[3]); + } + + private void VerifyDatabaseDetail(object[] expected, DatabaseInfo actual) + { + Assert.Equal(expected[0], actual.Options[ListDatabasesRequestDatabaseProperties.Name]); + Assert.Equal(expected[1], actual.Options[ListDatabasesRequestDatabaseProperties.State]); + Assert.Equal(expected[2], actual.Options[ListDatabasesRequestDatabaseProperties.SizeInMB]); + Assert.Equal(expected[3], actual.Options[ListDatabasesRequestDatabaseProperties.LastBackup]); + } + + + /// + /// Verify that the factory is returning DatabaseNamesHandler + /// + [Fact] + public void ListDatabaseRequestFactoryReturnsDatabaseNamesHandler() + { + var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: false, isSqlDB: true); + Assert.IsType(typeof(DatabaseNamesHandler), handler); + handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: false, isSqlDB: false); + Assert.IsType(typeof(DatabaseNamesHandler), handler); + } + + /// + /// Verify that the factory is returning SqlDBDatabaseDetailHandler + /// + [Fact] + public void ListDatabaseRequestFactoryReturnsSqlDBHandler() + { + var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: true, isSqlDB: true); + Assert.IsType(typeof(SqlDBDatabaseDetailHandler), handler); + } + + /// + /// Verify that the factory is returning SqlServerDatabaseDetailHandler + /// + [Fact] + public void ListDatabaseRequestFactoryReturnsSqlServerHandler() + { + var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: true, isSqlDB: false); + Assert.IsType(typeof(SqlServerDatabaseDetailHandler), handler); + } + /// /// Verify that the SQL parser correctly detects errors in text /// @@ -914,7 +1001,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection // setup connection service with callback var connectionService = TestObjects.GetTestConnectionService(); connectionService.RegisterOnConnectionTask( - (sqlConnection) => { + (sqlConnection) => + { callbackInvoked = true; return Task.FromResult(true); } @@ -1048,7 +1136,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection /// Test that the connection summary comparer creates a hash code correctly /// [Theory] - [InlineData(true, null, null ,null)] + [InlineData(true, null, null, null)] [InlineData(false, null, null, null)] [InlineData(false, null, null, "sa")] [InlineData(false, null, "test", null)] @@ -1586,7 +1674,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection [Fact] public async void ConnectingWithAzureAccountUsesToken() { - // Set up mock connection factory + // Set up mock connection factory var mockFactory = new Mock(); mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny(), It.IsAny())) .Returns(new TestSqlConnection(null)); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs index 660cd25f..1cc989cd 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Utility/TestDbDataReader.cs @@ -16,17 +16,17 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility public class TestDbException : DbException { } - + public class TestDbDataReader : DbDataReader, IDbColumnSchemaGenerator { #region Test Specific Implementations - + private IEnumerable Data { get; } private IEnumerator ResultSetEnumerator { get; } private IEnumerator RowEnumerator { get; set; } - + private bool ThrowOnRead { get; } public TestDbDataReader(IEnumerable data, bool throwOnRead) @@ -140,7 +140,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility { if (ResultSetEnumerator.Current.Columns[ordinal].DataType == typeof(byte[])) { - byte[] data = (byte[]) this[ordinal]; + byte[] data = (byte[])this[ordinal]; if (buffer == null) { return data.Length; @@ -173,11 +173,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) { - char[] allChars = ((string) RowEnumerator.Current[ordinal]).ToCharArray(); + char[] allChars = ((string)RowEnumerator.Current[ordinal]).ToCharArray(); int outLength = allChars.Length; if (buffer != null) { - Array.Copy(allChars, (int) dataOffset, buffer, bufferOffset, outLength); + Array.Copy(allChars, (int)dataOffset, buffer, bufferOffset, outLength); } return outLength; } @@ -256,7 +256,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility { get { - throw new NotImplementedException(); + var column = ResultSetEnumerator?.Current.Columns.FindIndex(c => c.ColumnName == name); + if (!column.HasValue) + { + throw new ArgumentOutOfRangeException(); + } + return RowEnumerator.Current[column.Value]; } }