extend the list databases request (#951)

* extend the list databases request

* reuse databaseInfo and refactor

* fix typo
This commit is contained in:
Alan Ren
2020-04-28 16:12:56 -07:00
committed by GitHub
parent 96df91c8fa
commit 8f6662b019
6 changed files with 369 additions and 71 deletions

View File

@@ -910,41 +910,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection
{
throw new Exception(SR.ConnectionServiceListDbErrorNotConnected(owner));
}
ConnectionDetails connectionDetails = info.ConnectionDetails.Clone();
// Connect to master and query sys.databases
connectionDetails.DatabaseName = "master";
var connection = this.ConnectionFactory.CreateSqlConnection(BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken);
connection.Open();
List<string> results = new List<string>();
var systemDatabases = new[] { "master", "model", "msdb", "tempdb" };
using (DbCommand command = connection.CreateCommand())
{
command.CommandText = @"SELECT name FROM sys.databases WHERE state_desc='ONLINE' ORDER BY name ASC";
command.CommandTimeout = 15;
command.CommandType = CommandType.Text;
using (var reader = command.ExecuteReader())
{
while (reader.Read())
{
results.Add(reader[0].ToString());
}
}
}
// Put system databases at the top of the list
results =
results.Where(s => systemDatabases.Any(s.Equals)).Concat(
results.Where(s => systemDatabases.All(x => !s.Equals(x)))).ToList();
connection.Close();
ListDatabasesResponse response = new ListDatabasesResponse();
response.DatabaseNames = results.ToArray();
return response;
var handler = ListDatabaseRequestHandlerFactory.getHandler(listDatabasesParams.IncludeDetails.HasTrue(), info.IsSqlDb);
return handler.HandleRequest(this.connectionFactory, info);
}
public void InitializeService(IProtocolEndpoint serviceHost)

View File

@@ -14,5 +14,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// URI of the owner of the connection requesting the list of databases.
/// </summary>
public string OwnerUri { get; set; }
/// <summary>
/// whether to include the details of the databases.
/// </summary>
public bool? IncludeDetails { get; set; }
}
}

View File

@@ -3,6 +3,8 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using Microsoft.SqlTools.ServiceLayer.Admin.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
{
/// <summary>
@@ -14,5 +16,10 @@ namespace Microsoft.SqlTools.ServiceLayer.Connection.Contracts
/// Gets or sets the list of database names.
/// </summary>
public string[] DatabaseNames { get; set; }
/// <summary>
/// Gets or sets the databases details.
/// </summary>
public DatabaseInfo[] Databases { get; set; }
}
}

View File

@@ -0,0 +1,226 @@
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Linq;
using Microsoft.SqlTools.ServiceLayer.Admin.Contracts;
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.Connection
{
public class ListDatabasesRequestDatabaseProperties
{
public const string Name = "name";
public const string SizeInMB = "sizeInMB";
public const string State = "state";
public const string LastBackup = "lastBackup";
}
/// <summary>
/// Factory class for ListDatabasesRequest handler
/// </summary>
static class ListDatabaseRequestHandlerFactory
{
public static IListDatabaseRequestHandler getHandler(bool includeDetails, bool isSqlDB)
{
if (!includeDetails)
{
return new DatabaseNamesHandler();
}
else if (isSqlDB)
{
return new SqlDBDatabaseDetailHandler();
}
else
{
return new SqlServerDatabaseDetailHandler();
}
}
}
/// <summary>
/// Interface of ListDatabasesRequest handler
/// </summary>
interface IListDatabaseRequestHandler
{
ListDatabasesResponse HandleRequest(ISqlConnectionFactory connectionFactory, ConnectionInfo connectionInfo);
}
/// <summary>
/// Base handler
/// </summary>
abstract class ListDatabaseRequestHandler<T> : IListDatabaseRequestHandler
{
private static readonly string[] SystemDatabases = new string[] { "master", "model", "msdb", "tempdb" };
public abstract string QueryText { get; }
public ListDatabasesResponse HandleRequest(ISqlConnectionFactory connectionFactory, ConnectionInfo connectionInfo)
{
ConnectionDetails connectionDetails = connectionInfo.ConnectionDetails.Clone();
// Connect to master
connectionDetails.DatabaseName = "master";
using (var connection = connectionFactory.CreateSqlConnection(ConnectionService.BuildConnectionString(connectionDetails), connectionDetails.AzureAccountToken))
{
connection.Open();
ListDatabasesResponse response = new ListDatabasesResponse();
using (DbCommand command = connection.CreateCommand())
{
command.CommandText = this.QueryText;
command.CommandTimeout = 15;
command.CommandType = CommandType.Text;
using (var reader = command.ExecuteReader())
{
List<T> results = new List<T>();
while (reader.Read())
{
results.Add(this.CreateItem(reader));
}
// Put system databases at the top of the list
results = results.Where(s => SystemDatabases.Any(x => this.NameMatches(x, s))).Concat(
results.Where(s => SystemDatabases.All(x => !this.NameMatches(x, s)))).ToList();
SetResponse(response, results.ToArray());
}
}
connection.Close();
return response;
}
}
protected abstract bool NameMatches(string databaseName, T item);
protected abstract T CreateItem(DbDataReader reader);
protected abstract void SetResponse(ListDatabasesResponse response, T[] results);
}
/// <summary>
/// database names handler
/// </summary>
class DatabaseNamesHandler : ListDatabaseRequestHandler<string>
{
public override string QueryText
{
get
{
return @"SELECT name FROM sys.databases WHERE state_desc='ONLINE' ORDER BY name ASC";
}
}
protected override string CreateItem(DbDataReader reader)
{
return reader[0].ToString();
}
protected override bool NameMatches(string databaseName, string item)
{
return databaseName == item;
}
protected override void SetResponse(ListDatabasesResponse response, string[] results)
{
response.DatabaseNames = results;
}
}
abstract class BaseDatabaseDetailHandler : ListDatabaseRequestHandler<DatabaseInfo>
{
protected override bool NameMatches(string databaseName, DatabaseInfo item)
{
return databaseName == item.Options[ListDatabasesRequestDatabaseProperties.Name].ToString();
}
protected override void SetResponse(ListDatabasesResponse response, DatabaseInfo[] results)
{
response.Databases = results;
}
protected override DatabaseInfo CreateItem(DbDataReader reader)
{
DatabaseInfo databaseInfo = new DatabaseInfo();
SetProperties(reader, databaseInfo);
return databaseInfo;
}
protected virtual void SetProperties(DbDataReader reader, DatabaseInfo databaseInfo)
{
databaseInfo.Options[ListDatabasesRequestDatabaseProperties.Name] = reader["name"].ToString();
databaseInfo.Options[ListDatabasesRequestDatabaseProperties.State] = reader["state"].ToString();
databaseInfo.Options[ListDatabasesRequestDatabaseProperties.SizeInMB] = reader["size"].ToString();
}
}
/// <summary>
/// Standalone SQL Server database detail handler
/// </summary>
class SqlServerDatabaseDetailHandler : BaseDatabaseDetailHandler
{
public override string QueryText
{
get
{
return @"
WITH
db_size
AS
(
SELECT database_id, CAST(SUM(size) * 8.0 / 1024 AS INTEGER) size
FROM sys.master_files
GROUP BY database_id
),
db_backup
AS
(
SELECT database_name, MAX(backup_start_date) AS last_backup
FROM msdb..backupset
GROUP BY database_name
)
SELECT name, state_desc AS state, db_size.size, db_backup.last_backup
FROM sys.databases LEFT JOIN db_size ON sys.databases.database_id = db_size.database_id
LEFT JOIN db_backup ON sys.databases.name = db_backup.database_name
WHERE state_desc='ONLINE'
ORDER BY name ASC";
}
}
protected override void SetProperties(DbDataReader reader, DatabaseInfo databaseInfo)
{
base.SetProperties(reader, databaseInfo);
databaseInfo.Options[ListDatabasesRequestDatabaseProperties.LastBackup] = reader["last_backup"] == DBNull.Value ? "" : Convert.ToDateTime(reader["last_backup"]).ToString("yyyy-MM-dd hh:mm:ss");
}
}
/// <summary>
/// SQL DB database detail handler
/// </summary>
class SqlDBDatabaseDetailHandler : BaseDatabaseDetailHandler
{
public override string QueryText
{
get
{
return @"
WITH
db_size
AS
(
SELECT name, storage_in_megabytes AS size
FROM (
SELECT database_name name, max(end_time) size_time
FROM sys.resource_stats
GROUP BY database_name) db_size_time
LEFT JOIN sys.resource_stats ON database_name = name AND size_time = end_time
)
SELECT db.name, state_desc AS state, size
FROM sys.databases db LEFT JOIN db_size ON db.name = db_size.name
WHERE state_desc='ONLINE'
ORDER BY name ASC
";
}
}
}
}

View File

@@ -20,6 +20,7 @@ using Moq;
using Moq.Protected;
using Xunit;
using System.Linq;
using Microsoft.SqlTools.ServiceLayer.Admin.Contracts;
namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
{
@@ -144,7 +145,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// Given a second connection that succeeds
var mockConnection2 = new Mock<DbConnection> { CallBase = true };
mockConnection2.Setup(x => x.OpenAsync(It.IsAny<CancellationToken>()))
.Returns(() => Task.Run(() => {}));
.Returns(() => Task.Run(() => { }));
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.SetupSequence(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
@@ -338,7 +339,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// register disconnect callback
connectionService.RegisterOnDisconnectTask(
(result, uri) => {
(result, uri) =>
{
callbackInvoked = true;
Assert.True(uri.Equals(ownerUri));
return Task.FromResult(true);
@@ -431,7 +433,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
.Connect(new ConnectParams()
{
OwnerUri = ownerUri,
Connection = new ConnectionDetails() {
Connection = new ConnectionDetails()
{
ServerName = server,
DatabaseName = database,
UserName = userName,
@@ -465,7 +468,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
.Connect(new ConnectParams()
{
OwnerUri = "file:///my/test/file.sql",
Connection = new ConnectionDetails() {
Connection = new ConnectionDetails()
{
ServerName = "my-server",
DatabaseName = "test",
UserName = userName,
@@ -761,7 +765,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// register disconnect callback
connectionService.RegisterOnDisconnectTask(
(result, uri) => {
(result, uri) =>
{
callbackInvoked = true;
Assert.True(uri.Equals(ownerUri));
return Task.FromResult(true);
@@ -853,28 +858,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.False(disconnectResult);
}
/// <summary>
/// Verifies the the list databases operation lists database names for the server used by a connection.
/// </summary>
[Fact]
public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames()
private async Task<ListDatabasesResponse> RunListDatabasesRequestHandler(TestResultSet testdata, bool? includeDetails)
{
// Result set for the query of database names
TestDbColumn[] cols = {new TestDbColumn("name")};
object[][] rows =
{
new object[] {"master"},
new object[] {"model"},
new object[] {"msdb"},
new object[] {"tempdb"},
new object[] {"mydatabase"}
};
TestResultSet data = new TestResultSet(cols, rows);
// Setup mock connection factory to inject query results
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(CreateMockDbConnection(new[] {data}));
.Returns(CreateMockDbConnection(new[] { testdata }));
var connectionService = new ConnectionService(mockFactory.Object);
// connect to a database instance
@@ -891,9 +880,34 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.NotEmpty(connectionResult.ConnectionId);
// list databases for the connection
ListDatabasesParams parameters = new ListDatabasesParams {OwnerUri = ownerUri};
var listDatabasesResult = connectionService.ListDatabases(parameters);
string[] databaseNames = listDatabasesResult.DatabaseNames;
ListDatabasesParams parameters = new ListDatabasesParams
{
OwnerUri = ownerUri,
IncludeDetails = includeDetails
};
return connectionService.ListDatabases(parameters);
}
/// <summary>
/// Verifies the the list databases operation lists database names for the server used by a connection.
/// </summary>
[Fact]
public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseNames()
{
// Result set for the query of database names
TestDbColumn[] cols = { new TestDbColumn("name") };
object[][] rows =
{
new object[] {"mydatabase"}, // this should be sorted to the end in the response
new object[] {"master"},
new object[] {"model"},
new object[] {"msdb"},
new object[] {"tempdb"}
};
TestResultSet data = new TestResultSet(cols, rows);
var response = await RunListDatabasesRequestHandler(testdata: data, includeDetails: null);
string[] databaseNames = response.DatabaseNames;
Assert.Equal(databaseNames.Length, 5);
Assert.Equal(databaseNames[0], "master");
@@ -903,6 +917,79 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
Assert.Equal(databaseNames[4], "mydatabase");
}
/// <summary>
/// Verifies the the list databases operation lists database names for the server used by a connection.
/// </summary>
[Fact]
public async Task ListDatabasesOnServerForCurrentConnectionReturnsDatabaseDetails()
{
// Result set for the query of database names
TestDbColumn[] cols = {
new TestDbColumn("name"),
new TestDbColumn("state"),
new TestDbColumn("size"),
new TestDbColumn("last_backup")
};
object[][] rows =
{
new object[] {"mydatabase", "Online", "10", "2010-01-01 11:11:11"}, // this should be sorted to the end in the response
new object[] {"master", "Online", "11", "2010-01-01 11:11:12"},
new object[] {"model", "Offline", "12", "2010-01-01 11:11:13"},
new object[] {"msdb", "Online", "13", "2010-01-01 11:11:14"},
new object[] {"tempdb", "Online", "14", "2010-01-01 11:11:15"}
};
TestResultSet data = new TestResultSet(cols, rows);
var response = await RunListDatabasesRequestHandler(testdata: data, includeDetails: true);
Assert.Equal(response.Databases.Length, 5);
VerifyDatabaseDetail(rows[0], response.Databases[4]);
VerifyDatabaseDetail(rows[1], response.Databases[0]);
VerifyDatabaseDetail(rows[2], response.Databases[1]);
VerifyDatabaseDetail(rows[3], response.Databases[2]);
VerifyDatabaseDetail(rows[4], response.Databases[3]);
}
private void VerifyDatabaseDetail(object[] expected, DatabaseInfo actual)
{
Assert.Equal(expected[0], actual.Options[ListDatabasesRequestDatabaseProperties.Name]);
Assert.Equal(expected[1], actual.Options[ListDatabasesRequestDatabaseProperties.State]);
Assert.Equal(expected[2], actual.Options[ListDatabasesRequestDatabaseProperties.SizeInMB]);
Assert.Equal(expected[3], actual.Options[ListDatabasesRequestDatabaseProperties.LastBackup]);
}
/// <summary>
/// Verify that the factory is returning DatabaseNamesHandler
/// </summary>
[Fact]
public void ListDatabaseRequestFactoryReturnsDatabaseNamesHandler()
{
var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: false, isSqlDB: true);
Assert.IsType(typeof(DatabaseNamesHandler), handler);
handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: false, isSqlDB: false);
Assert.IsType(typeof(DatabaseNamesHandler), handler);
}
/// <summary>
/// Verify that the factory is returning SqlDBDatabaseDetailHandler
/// </summary>
[Fact]
public void ListDatabaseRequestFactoryReturnsSqlDBHandler()
{
var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: true, isSqlDB: true);
Assert.IsType(typeof(SqlDBDatabaseDetailHandler), handler);
}
/// <summary>
/// Verify that the factory is returning SqlServerDatabaseDetailHandler
/// </summary>
[Fact]
public void ListDatabaseRequestFactoryReturnsSqlServerHandler()
{
var handler = ListDatabaseRequestHandlerFactory.getHandler(includeDetails: true, isSqlDB: false);
Assert.IsType(typeof(SqlServerDatabaseDetailHandler), handler);
}
/// <summary>
/// Verify that the SQL parser correctly detects errors in text
/// </summary>
@@ -914,7 +1001,8 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
// setup connection service with callback
var connectionService = TestObjects.GetTestConnectionService();
connectionService.RegisterOnConnectionTask(
(sqlConnection) => {
(sqlConnection) =>
{
callbackInvoked = true;
return Task.FromResult(true);
}
@@ -1048,7 +1136,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
/// Test that the connection summary comparer creates a hash code correctly
/// </summary>
[Theory]
[InlineData(true, null, null ,null)]
[InlineData(true, null, null, null)]
[InlineData(false, null, null, null)]
[InlineData(false, null, null, "sa")]
[InlineData(false, null, "test", null)]
@@ -1586,7 +1674,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Connection
[Fact]
public async void ConnectingWithAzureAccountUsesToken()
{
// Set up mock connection factory
// Set up mock connection factory
var mockFactory = new Mock<ISqlConnectionFactory>();
mockFactory.Setup(factory => factory.CreateSqlConnection(It.IsAny<string>(), It.IsAny<string>()))
.Returns(new TestSqlConnection(null));

View File

@@ -16,17 +16,17 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
public class TestDbException : DbException
{
}
public class TestDbDataReader : DbDataReader, IDbColumnSchemaGenerator
{
#region Test Specific Implementations
private IEnumerable<TestResultSet> Data { get; }
private IEnumerator<TestResultSet> ResultSetEnumerator { get; }
private IEnumerator<object[]> RowEnumerator { get; set; }
private bool ThrowOnRead { get; }
public TestDbDataReader(IEnumerable<TestResultSet> data, bool throwOnRead)
@@ -140,7 +140,7 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
{
if (ResultSetEnumerator.Current.Columns[ordinal].DataType == typeof(byte[]))
{
byte[] data = (byte[]) this[ordinal];
byte[] data = (byte[])this[ordinal];
if (buffer == null)
{
return data.Length;
@@ -173,11 +173,11 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length)
{
char[] allChars = ((string) RowEnumerator.Current[ordinal]).ToCharArray();
char[] allChars = ((string)RowEnumerator.Current[ordinal]).ToCharArray();
int outLength = allChars.Length;
if (buffer != null)
{
Array.Copy(allChars, (int) dataOffset, buffer, bufferOffset, outLength);
Array.Copy(allChars, (int)dataOffset, buffer, bufferOffset, outLength);
}
return outLength;
}
@@ -256,7 +256,12 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Utility
{
get
{
throw new NotImplementedException();
var column = ResultSetEnumerator?.Current.Columns.FindIndex(c => c.ColumnName == name);
if (!column.HasValue)
{
throw new ArgumentOutOfRangeException();
}
return RowEnumerator.Current[column.Value];
}
}