diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs index 7e4a76e4..bc65462c 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs @@ -61,7 +61,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata /// /// Handle a metadata query request /// - internal static async Task HandleMetadataListRequest( + internal async Task HandleMetadataListRequest( MetadataQueryParams metadataParams, RequestContext requestContext) { @@ -93,6 +93,7 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata { await requestContext.SendError(t.Exception.ToString()); }); + MetadataListTask = task; } catch (Exception ex) { @@ -100,6 +101,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata } } + internal Task MetadataListTask { get; set; } + /// /// Handle a table metadata query request /// @@ -169,11 +172,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata /// internal static void ReadMetadata(SqlConnection sqlConn, List metadata) { - string sql = + string sql = @"SELECT s.name AS schema_name, o.[name] AS object_name, o.[type] AS object_type FROM sys.all_objects o INNER JOIN sys.schemas s ON o.schema_id = s.schema_id - WHERE (o.[type] = 'P' OR o.[type] = 'V' OR o.[type] = 'U') "; + WHERE (o.[type] = 'P' OR o.[type] = 'V' OR o.[type] = 'U' OR o.[type] = 'AF' OR o.[type] = 'FN' OR o.[type] = 'IF') "; if (!IsSystemDatabase(sqlConn.Database)) { @@ -204,6 +207,11 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata metadataType = MetadataType.SProc; metadataTypeName = "StoredProcedure"; } + else if (objectType == "AF" || objectType == "FN" || objectType == "IF") + { + metadataType = MetadataType.Function; + metadataTypeName = "UserDefinedFunction"; + } else { metadataType = MetadataType.Table; diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs index 295de177..14c2a469 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs @@ -3,18 +3,22 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // -using System; -using System.Collections.Generic; -using System.Data.SqlClient; -using System.Threading.Tasks; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; using Microsoft.SqlTools.ServiceLayer.Metadata; using Microsoft.SqlTools.ServiceLayer.Metadata.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Common; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; using Moq; +using System; +using System.Collections.Generic; +using System.Data.SqlClient; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; using Xunit; +using static Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility.LiveConnectionHelper; namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata { @@ -142,5 +146,124 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata requestContext.VerifyAll(); } + [Fact] + public async void VerifyMetadataList() + { + string query = @"CREATE TABLE testTable1 (c1 int) + GO + CREATE PROCEDURE testSp1 @StartProductID [int] AS BEGIN Select * from sys.all_columns END + GO + CREATE VIEW testView1 AS SELECT * from sys.all_columns + GO + CREATE FUNCTION testFun1() RETURNS [int] AS BEGIN RETURN 1 END + GO + CREATE FUNCTION [testFun2](@CityID int) + RETURNS TABLE + WITH SCHEMABINDING + AS + RETURN SELECT 1 AS AccessResult + GO"; + + List expectedMetadataList = new List + { + new ObjectMetadata + { + MetadataType = MetadataType.Table, + MetadataTypeName = "Table", + Name = "testTable1", + Schema = "dbo" + }, + new ObjectMetadata + { + MetadataType = MetadataType.SProc, + MetadataTypeName = "StoredProcedure", + Name = "testSp1", + Schema = "dbo" + }, + new ObjectMetadata + { + MetadataType = MetadataType.View, + MetadataTypeName = "View", + Name = "testView1", + Schema = "dbo" + }, + new ObjectMetadata + { + MetadataType = MetadataType.Function, + MetadataTypeName = "UserDefinedFunction", + Name = "testFun1", + Schema = "dbo" + }, + new ObjectMetadata + { + MetadataType = MetadataType.Function, + MetadataTypeName = "UserDefinedFunction", + Name = "testFun2", + Schema = "dbo" + } + }; + + await VerifyMetadataList(query, expectedMetadataList); + } + + private async Task VerifyMetadataList(string query, List expectedMetadataList) + { + var testDb = await SqlTestDb.CreateNewAsync(TestServerType.OnPrem, false, null, query, "MetadataTests"); + try + { + var requestContext = new Mock>(); + requestContext.Setup(x => x.SendResult(It.IsAny())).Returns(Task.FromResult(new object())); + ConnectionService connectionService = LiveConnectionHelper.GetLiveTestConnectionService(); + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + //Opening a connection to db to lock the db + TestConnectionResult connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync(testDb.DatabaseName, queryTempFile.FilePath, ConnectionType.Default); + + MetadataService service = new MetadataService(); + await service.HandleMetadataListRequest(new MetadataQueryParams + { + OwnerUri = queryTempFile.FilePath + }, requestContext.Object); + Thread.Sleep(2000); + await service.MetadataListTask; + + requestContext.Verify(x => x.SendResult(It.Is(r => VerifyResult(r, expectedMetadataList)))); + connectionService.Disconnect(new ServiceLayer.Connection.Contracts.DisconnectParams + { + OwnerUri = queryTempFile.FilePath + }); + } + } + catch + { + throw; + } + finally + { + await testDb.CleanupAsync(); + } + } + + private static bool VerifyResult(MetadataQueryResult result, List expectedMetadataList) + { + if (expectedMetadataList == null) + { + return result.Metadata == null; + } + + if(expectedMetadataList.Count() != result.Metadata.Count()) + { + return false; + } + foreach (ObjectMetadata expected in expectedMetadataList) + { + if (!result.Metadata.Any(x => x.MetadataType == expected.MetadataType && x.MetadataTypeName == expected.MetadataTypeName && x.Name == expected.Name && x.Schema == expected.Schema)) + { + return false; + } + } + return true; + } + } }