From 08e855aa1d349a6a79765396de47c39a73312eba Mon Sep 17 00:00:00 2001 From: Lewis Sanchez <87730006+lewis-sanchez@users.noreply.github.com> Date: Tue, 22 Aug 2023 12:28:32 -0700 Subject: [PATCH] Implements Contextualization API into Azure Data Studio to get better query recommendations from extensions like Copilot (#2159) * Add contract to get all metadata request * Add new metadata service request endpoint * Adds factory to make database server scripts * Minor clean up * Corrects filename typo * Cleans up SmoScripterFactory * Stubs out metadata cacher * Method clean up * Add writing and reading to script cache * Cleans up request endpoint flow * Add missing edge case when cache isn't empty * Remove unused code * Remove unneeded null check * Read to end of stream * Passes correct parameter to write cache * Adds integration test to get all scripts * Renames new request endpoint * Rename request class to AllServerMetadataRequest * Renames server metadata request endpoints * Refresh cache and adjusts return obj type * Clean up * Assert table script generation * Minor cache refresh adjustment * Ensure test create table script is accurate * Code review changes * Additional code review changes * Swap logger write for logger warning * Renames generate request endpoint methods * Remove unused using statement * Remove unnecessary create table check * Check if previous script file is valid for reuse * Pascal case for method name * Code review changes * Fix PR issues * Update doc comment * Fixes tests after code review changes * Fix failing int. test due to 30 day temp file expiry * Generalize type names and update request endpoint * Updates doc comment. * Remove 'database' from type and method names * Code review changes * Code review changes * Issues with background thread. * Remove thread sleep for test reliability * Remove reflection from int. tests --- ...rateServerContextualizationNotification.cs | 26 +++ .../GetServerContextualizationRequest.cs | 31 +++ .../Metadata/MetadataScriptTempFileStream.cs | 140 ++++++++++++ .../Metadata/MetadataService.cs | 114 +++++++++- .../Metadata/SmoScripterHelpers.cs | 208 ++++++++++++++++++ .../Metadata/MetadataServiceTests.cs | 131 +++++++++-- 6 files changed, 627 insertions(+), 23 deletions(-) create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GenerateServerContextualizationNotification.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GetServerContextualizationRequest.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataScriptTempFileStream.cs create mode 100644 src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoScripterHelpers.cs diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GenerateServerContextualizationNotification.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GenerateServerContextualizationNotification.cs new file mode 100644 index 00000000..8285d005 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GenerateServerContextualizationNotification.cs @@ -0,0 +1,26 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Metadata.Contracts +{ + public class GenerateServerContextualizationParams + { + /// + /// The URI of the connection to generate context for. + /// + public string OwnerUri { get; set; } + } + + /// + /// Event set after a connection to a server is completed. + /// + public class GenerateServerContextualizationNotification + { + public static readonly EventType Type = + EventType.Create("metadata/generateServerContext"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GetServerContextualizationRequest.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GetServerContextualizationRequest.cs new file mode 100644 index 00000000..7bb76922 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/Contracts/GetServerContextualizationRequest.cs @@ -0,0 +1,31 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.Metadata.Contracts +{ + public class GetServerContextualizationParams + { + /// + /// The URI of the connection to generate scripts for. + /// + public string OwnerUri { get; set; } + } + + public class GetServerContextualizationResult + { + /// + /// An array containing the generated server context. + /// + public string[] Context { get; set; } + } + + public class GetServerContextualizationRequest + { + public static readonly RequestType Type = + RequestType.Create("metadata/getServerContext"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataScriptTempFileStream.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataScriptTempFileStream.cs new file mode 100644 index 00000000..91e067d3 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataScriptTempFileStream.cs @@ -0,0 +1,140 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Text; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.SqlTools.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.Metadata +{ + /// + /// This class is responsible for reading, writing, and checking the validity of script files. + /// + public static class MetadataScriptTempFileStream + { + private const short ScriptFileExpirationInDays = 30; + + /// + /// This method writes the passed in scripts to a temporary file. + /// + /// The name of the server which will go on to become the name of the file. + /// The generated scripts that will be written to the temporary file. + public static void Write(string serverName, IEnumerable scripts) + { + var encodedServerName = Base64Encode(serverName); + var tempFileName = $"{encodedServerName}.tmp"; + var generatedScripts = scripts.ToList(); + + try + { + var tempFilePath = Path.Combine(Path.GetTempPath(), tempFileName); + using (StreamWriter sw = new StreamWriter(tempFilePath, false)) + { + foreach (var script in generatedScripts) + { + sw.WriteLine(script); + } + } + } + catch (Exception ex) + { + Logger.Warning($"Failed to write scripts to temporary file. Error: {ex.Message}"); + throw; + } + } + + /// + /// Reads the scripts associated with the provided server name. + /// + /// The name of the server to retrieve the scripts for. + /// List containing all the scripts in the file. + public static IEnumerable Read(string serverName) + { + var encodedServerName = Base64Encode(serverName); + var tempFileName = $"{encodedServerName}.tmp"; + var scripts = new List(); + + try + { + var tempFilePath = Path.Combine(Path.GetTempPath(), tempFileName); + if (!File.Exists(tempFilePath)) + { + return scripts; + } + + using (StreamReader sr = new StreamReader(tempFilePath)) + { + while (!sr.EndOfStream) + { + var line = sr.ReadLine(); + if (!String.IsNullOrWhiteSpace(line)) + { + scripts.Add(line); + } + } + } + } + catch (Exception ex) + { + Logger.Warning($"Failed to read scripts from temporary file. Error: {ex.Message}"); + throw; + } + + return scripts; + } + + /// + /// Determines if the script file for a server is too old and needs to be updated + /// + /// The name of the file associated with the given server name. + /// True: The file was created within the expiration period; False: The script file needs to be created + /// or updated because it is too old. + public static bool IsScriptTempFileUpdateNeeded(string serverName) + { + var encodedServerName = Base64Encode(serverName); + var tempFileName = $"{encodedServerName}.tmp"; + + try + { + var tempFilePath = Path.Combine(Path.GetTempPath(), tempFileName); + if (!File.Exists(tempFilePath)) + { + return true; + } + else + { + /** + * Generated scripts don't need to be super up to date, so 30 days was chosen as the amount of time + * before the scripts are re-generated. This expiration date may change in the future, + * but for now this is what we're going with. + */ + var lastWriteTime = File.GetLastWriteTime(tempFilePath); + var isUpdateNeeded = (DateTime.Now - lastWriteTime).TotalDays < ScriptFileExpirationInDays ? false : true; + + return isUpdateNeeded; + } + } + catch (Exception ex) + { + Logger.Warning($"Unable to determine if the script file is older than {ScriptFileExpirationInDays} days. Error: {ex.Message}"); + throw; + } + } + + /// + /// Encodes a string to it's base 64 string representation. + /// + /// The string to base64 encode. + /// Base64 encoded string. + private static string Base64Encode(string str) + { + var bytes = Encoding.UTF8.GetBytes(str); + return Convert.ToBase64String(bytes); + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs index 4996852a..4e6a0e46 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/MetadataService.cs @@ -7,14 +7,17 @@ using System; using System.Collections.Generic; -using Microsoft.Data.SqlClient; +using System.Linq; +using System.Threading; using System.Threading.Tasks; +using Microsoft.Data.SqlClient; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.Hosting; using Microsoft.SqlTools.ServiceLayer.Metadata.Contracts; using Microsoft.SqlTools.ServiceLayer.Utility; using Microsoft.SqlTools.SqlCore.Metadata; +using Microsoft.SqlTools.Utility; namespace Microsoft.SqlTools.ServiceLayer.Metadata { @@ -56,6 +59,8 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata serviceHost.SetRequestHandler(MetadataListRequest.Type, HandleMetadataListRequest, true); serviceHost.SetRequestHandler(TableMetadataRequest.Type, HandleGetTableRequest, true); serviceHost.SetRequestHandler(ViewMetadataRequest.Type, HandleGetViewRequest, true); + serviceHost.SetEventHandler(GenerateServerContextualizationNotification.Type, HandleGenerateServerContextualizationNotification, true); + serviceHost.SetRequestHandler(GetServerContextualizationRequest.Type, HandleGetServerContextualizationRequest, true); } /// @@ -116,6 +121,113 @@ namespace Microsoft.SqlTools.ServiceLayer.Metadata await HandleGetTableOrViewRequest(metadataParams, "view", requestContext); } + /// + /// Handles the event for generating server contextualization scripts. + /// + internal static Task HandleGenerateServerContextualizationNotification(GenerateServerContextualizationParams contextualizationParams, + EventContext eventContext) + { + _ = Task.Factory.StartNew(() => + { + GenerateServerContextualization(contextualizationParams); + }, + CancellationToken.None, + TaskCreationOptions.None, + TaskScheduler.Default); + + return Task.CompletedTask; + } + + /// + /// Generates the contextualization scripts for a server. The generated context is in the form of create scripts for + /// database objects like tables and views. + /// + /// The contextualization parameters. + internal static void GenerateServerContextualization(GenerateServerContextualizationParams contextualizationParams) + { + MetadataService.ConnectionServiceInstance.TryFindConnection(contextualizationParams.OwnerUri, out ConnectionInfo connectionInfo); + + if (connectionInfo != null) + { + using (SqlConnection sqlConn = ConnectionService.OpenSqlConnection(connectionInfo, "metadata")) + { + // If scripts have been generated within the last 30 days then there isn't a need to go through the process + // of generating scripts again. + if (!MetadataScriptTempFileStream.IsScriptTempFileUpdateNeeded(connectionInfo.ConnectionDetails.ServerName)) + { + return; + } + + var scripts = SmoScripterHelpers.GenerateAllServerTableScripts(sqlConn); + if (scripts != null) + { + try + { + MetadataScriptTempFileStream.Write(connectionInfo.ConnectionDetails.ServerName, scripts); + } + catch (Exception ex) + { + Logger.Error($"An error was encountered while writing to the cache. Error: {ex.Message}"); + } + } + else + { + Logger.Error("Failed to generate server scripts"); + } + } + } + } + + /// + /// Handles the request for getting database server contextualization scripts. + /// + internal static Task HandleGetServerContextualizationRequest(GetServerContextualizationParams contextualizationParams, + RequestContext requestContext) + { + _ = Task.Factory.StartNew(async () => + { + await GetServerContextualization(contextualizationParams, requestContext); + }, + CancellationToken.None, + TaskCreationOptions.None, + TaskScheduler.Default); + + return Task.CompletedTask; + } + + /// + /// Gets server contextualization scripts. The retrieved scripts are create scripts for database objects like tables and views. + /// + /// The contextualization parameters to get context. + /// The request context for the request. + /// + internal static async Task GetServerContextualization(GetServerContextualizationParams contextualizationParams, RequestContext requestContext) + { + MetadataService.ConnectionServiceInstance.TryFindConnection(contextualizationParams.OwnerUri, out ConnectionInfo connectionInfo); + + if (connectionInfo != null) + { + try + { + var scripts = MetadataScriptTempFileStream.Read(connectionInfo.ConnectionDetails.ServerName); + await requestContext.SendResult(new GetServerContextualizationResult + { + Context = scripts.ToArray() + }); + } + catch (Exception ex) + { + Logger.Error("Failed to read scripts from the script cache"); + await requestContext.SendError(ex); + } + } + else + { + Logger.Error("Failed to find connection info about the server."); + await requestContext.SendError("Failed to find connection info about the server."); + } + } + /// /// Handle a table pr view metadata query request /// diff --git a/src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoScripterHelpers.cs b/src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoScripterHelpers.cs new file mode 100644 index 00000000..414913a2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/Metadata/SmoScripterHelpers.cs @@ -0,0 +1,208 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; +using System.Collections.Generic; +using System.Data.Common; +using Microsoft.Data.SqlClient; +using Microsoft.SqlServer.Management.Common; +using Microsoft.SqlServer.Management.Smo; +using Microsoft.SqlTools.ServiceLayer.Connection.ReliableConnection; +using Microsoft.SqlTools.SqlCore.Connection; +using Microsoft.SqlTools.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.Metadata +{ + internal static class SmoScripterHelpers + { + public static IEnumerable? GenerateAllServerTableScripts(DbConnection connection) + { + var serverConnection = SmoScripterHelpers.GetServerConnection(connection); + if (serverConnection == null) + { + return null; + } + + Server server = new Server(serverConnection); + var scripts = SmoScripterHelpers.GenerateTableScripts(server); + + return scripts; + } + + private static ServerConnection? GetServerConnection(DbConnection connection) + { + // Get a connection to the database for SMO purposes + var sqlConnection = connection as SqlConnection ?? SmoScripterHelpers.TryFindingReliableSqlConnection(connection as ReliableSqlConnection); + if (sqlConnection == null) + { + return null; + } + + var serverConnection = SmoScripterHelpers.ConnectToServerWithSmo(sqlConnection); + return serverConnection; + } + + private static SqlConnection? TryFindingReliableSqlConnection(ReliableSqlConnection reliableSqlConnection) + { + // It's not actually a SqlConnection, so let's try a reliable SQL connection + if (reliableSqlConnection == null) + { + // If we don't have connection we can use with SMO, just give up on using SMO + return null; + } + + // We have a reliable connection, use the underlying connection + return reliableSqlConnection.GetUnderlyingConnection(); + } + + private static ServerConnection ConnectToServerWithSmo(SqlConnection connection) + { + // Connect with SMO and get the metadata for the table + var serverConnection = (connection.AccessToken == null) + ? new ServerConnection(connection) + : new ServerConnection(connection, new AzureAccessToken(connection.AccessToken)); + + return serverConnection; + } + + private static IEnumerable GenerateTableScripts(Server server) + { + var urns = SmoScripterHelpers.GetAllServerTableAndViewUrns(server); + + var scriptingOptions = new ScriptingOptions + { + AgentAlertJob = false, + AgentJobId = false, + AgentNotify = false, + AllowSystemObjects = false, + AnsiFile = false, + AnsiPadding = false, + AppendToFile = false, + Bindings = false, + ChangeTracking = false, + ClusteredIndexes = false, + ColumnStoreIndexes = false, + ContinueScriptingOnError = true, + ConvertUserDefinedDataTypesToBaseType = false, + DdlBodyOnly = false, + DdlHeaderOnly = true, + DriAll = false, + DriAllConstraints = false, + DriAllKeys = false, + DriChecks = false, + DriClustered = false, + DriDefaults = false, + DriForeignKeys = false, + DriIncludeSystemNames = false, + DriIndexes = false, + DriNonClustered = false, + DriPrimaryKey = false, + DriUniqueKeys = false, + DriWithNoCheck = false, + EnforceScriptingOptions = true, + ExtendedProperties = false, + FullTextCatalogs = false, + FullTextIndexes = false, + FullTextStopLists = false, + IncludeDatabaseContext = false, + IncludeDatabaseRoleMemberships = false, + IncludeFullTextCatalogRootPath = false, + IncludeHeaders = false, + IncludeIfNotExists = false, + IncludeScriptingParametersHeader = false, + Indexes = false, + LoginSid = false, + NoAssemblies = true, + NoCollation = true, + NoCommandTerminator = true, + NoExecuteAs = true, + NoFileGroup = true, + NoFileStream = true, + NoFileStreamColumn = true, + NoIdentities = true, + NoIndexPartitioningSchemes = true, + NoMailProfileAccounts = true, + NoMailProfilePrincipals = true, + NonClusteredIndexes = false, + NoTablePartitioningSchemes = true, + NoVardecimal = false, + NoViewColumns = false, + NoXmlNamespaces = false, + OptimizerData = false, + Permissions = false, + PrimaryObject = true, + SchemaQualify = true, + SchemaQualifyForeignKeysReferences = true, + ScriptBatchTerminator = false, + ScriptData = false, + ScriptDataCompression = false, + ScriptDrops = false, + ScriptForAlter = false, + ScriptForCreateDrop = false, + ScriptForCreateOrAlter = true, + ScriptOwner = false, + ScriptSchema = true, + ScriptXmlCompression = false, + SpatialIndexes = false, + Statistics = false, + TimestampToBinary = false, + ToFileOnly = false, + Triggers = false, + WithDependencies = false, + XmlIndexes = false + }; + + var scripter = new Scripter(server); + scripter.Options = scriptingOptions; + var generatedScripts = scripter.Script(urns); + + var scripts = new List(); + foreach (var s in generatedScripts) + { + // Needed to remove '\r' and '\n' characters from script, so that an entire create script + // can be written and read as a single line to and from a temp file. Since scripts aren't + // going to be read by people, and mainly sent to Copilot to generate accurate suggestions, + // a lack of formatting is fine. + var script = s.Replace("\r", string.Empty).Replace("\n", string.Empty); + scripts.Add(script); + } + + return scripts; + } + + private static UrnCollection GetAllServerTableAndViewUrns(Server server) + { + UrnCollection urnCollection = new UrnCollection(); + + foreach (Database db in server.Databases) + { + try + { + foreach (SqlServer.Management.Smo.Table t in db.Tables) + { + urnCollection.Add(t.Urn); + } + } + catch (Exception ex) + { + Logger.Warning($"Unable to get table URNs. Error: {ex.Message}"); + } + + try + { + foreach (SqlServer.Management.Smo.View v in db.Views) + { + urnCollection.Add(v.Urn); + } + } + catch (Exception ex) + { + Logger.Warning($"Unable to get view URNs. Error: {ex.Message}"); + } + } + return urnCollection; + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs index a94afaef..d3daf185 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/Metadata/MetadataServiceTests.cs @@ -5,6 +5,14 @@ #nullable disable +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; using Microsoft.SqlTools.Hosting.Protocol; using Microsoft.SqlTools.ServiceLayer.Connection; using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; @@ -12,16 +20,10 @@ using Microsoft.SqlTools.ServiceLayer.Metadata; using Microsoft.SqlTools.ServiceLayer.Metadata.Contracts; using Microsoft.SqlTools.ServiceLayer.Test.Common; using Microsoft.SqlTools.ServiceLayer.Workspace.Contracts; +using Microsoft.SqlTools.SqlCore.Metadata; using Moq; -using System; -using System.Collections.Generic; -using Microsoft.Data.SqlClient; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; using NUnit.Framework; using static Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility.LiveConnectionHelper; -using Microsoft.SqlTools.SqlCore.Metadata; namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata { @@ -32,6 +34,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata { private string testTableSchema = "dbo"; private string testTableName = "MetadataTestTable"; + private string testTableName2 = "SecondMetadataTestTable"; private LiveConnectionHelper.TestConnectionResult GetLiveAutoCompleteTestObjects() { @@ -50,20 +53,20 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata return result; } - private void CreateTestTable(SqlConnection sqlConn) + private void CreateTestTable(SqlConnection sqlConn, string testTableSchema, string testTableName) { string sql = string.Format("IF OBJECT_ID('{0}.{1}', 'U') IS NULL CREATE TABLE {0}.{1}(id int)", - this.testTableSchema, this.testTableName); + testTableSchema, testTableName); using (var sqlCommand = new SqlCommand(sql, sqlConn)) { - sqlCommand.ExecuteNonQuery(); - } + sqlCommand.ExecuteNonQuery(); + } } - private void DeleteTestTable(SqlConnection sqlConn) + private void DeleteTestTable(SqlConnection sqlConn, string testTableSchema, string testTableName) { string sql = string.Format("IF OBJECT_ID('{0}.{1}', 'U') IS NOT NULL DROP TABLE {0}.{1}", - this.testTableSchema, this.testTableName); + testTableSchema, testTableName); using (var sqlCommand = new SqlCommand(sql, sqlConn)) { sqlCommand.ExecuteNonQuery(); @@ -82,7 +85,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata var sqlConn = ConnectionService.OpenSqlConnection(result.ConnectionInfo); Assert.NotNull(sqlConn); - CreateTestTable(sqlConn); + CreateTestTable(sqlConn, this.testTableSchema, this.testTableName); var metadata = new List(); MetadataService.ReadMetadata(sqlConn, metadata); @@ -100,18 +103,18 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata } Assert.True(foundTestTable); - DeleteTestTable(sqlConn); + DeleteTestTable(sqlConn, this.testTableSchema, this.testTableName); } [Test] public async Task GetTableInfoReturnsValidResults() { this.testTableName += new Random().Next(1000000, 9999999).ToString(); - + var result = GetLiveAutoCompleteTestObjects(); var sqlConn = ConnectionService.OpenSqlConnection(result.ConnectionInfo); - CreateTestTable(sqlConn); + CreateTestTable(sqlConn, this.testTableSchema, this.testTableName); var requestContext = new Mock>(); requestContext.Setup(x => x.SendResult(It.IsAny())).Returns(Task.FromResult(new object())); @@ -125,15 +128,99 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata await MetadataService.HandleGetTableRequest(metadataParmas, requestContext.Object); - DeleteTestTable(sqlConn); + DeleteTestTable(sqlConn, this.testTableSchema, this.testTableName); requestContext.VerifyAll(); } + [Test] + public async Task VerifyGenerateServerContextualizationNotification() + { + this.testTableName += new Random().Next(1000000, 9999999).ToString(); + this.testTableName2 += new Random().Next(0, 999999).ToString(); + + var connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(null); + var sqlConn = ConnectionService.OpenSqlConnection(connectionResult.ConnectionInfo); + + CreateTestTable(sqlConn, this.testTableSchema, this.testTableName); + CreateTestTable(sqlConn, this.testTableSchema, this.testTableName2); + + var generateServerContextualizationParams = new GenerateServerContextualizationParams + { + OwnerUri = connectionResult.ConnectionInfo.OwnerUri + }; + + MetadataService.GenerateServerContextualization(generateServerContextualizationParams); + + DeleteTestTable(sqlConn, this.testTableSchema, this.testTableName); + DeleteTestTable(sqlConn, this.testTableSchema, this.testTableName2); + + DeleteServerContextualizationTempFile(sqlConn.DataSource); + } + + [Test] + public async Task VerifyGetServerContextualizationRequest() + { + this.testTableName += new Random().Next(1000000, 9999999).ToString(); + this.testTableName2 += new Random().Next(1000000, 9999999).ToString(); + + var connectionResult = LiveConnectionHelper.InitLiveConnectionInfo(null); + var sqlConn = ConnectionService.OpenSqlConnection(connectionResult.ConnectionInfo); + + CreateTestTable(sqlConn, this.testTableSchema, this.testTableName); + CreateTestTable(sqlConn, this.testTableSchema, this.testTableName2); + + var generateServerContextualizationParams = new GenerateServerContextualizationParams + { + OwnerUri = connectionResult.ConnectionInfo.OwnerUri + }; + + MetadataService.GenerateServerContextualization(generateServerContextualizationParams); + + DeleteTestTable(sqlConn, this.testTableSchema, this.testTableName); + DeleteTestTable(sqlConn, this.testTableSchema, this.testTableName2); + + var firstCreateTableScript = $"CREATE TABLE [{this.testTableSchema}].[{this.testTableName}](\t[id] [int] NULL)"; + var secondCreateTableScript = $"CREATE TABLE [{this.testTableSchema}].[{this.testTableName2}](\t[id] [int] NULL)"; + + var mockGetServerContextualizationRequestContext = new Mock>(); + var actualGetServerContextualizationResponse = new GetServerContextualizationResult(); + mockGetServerContextualizationRequestContext.Setup(x => x.SendResult(It.IsAny())) + .Callback(actual => actualGetServerContextualizationResponse = actual) + .Returns(Task.CompletedTask); + + var getServerContextualizationParams = new GetServerContextualizationParams + { + OwnerUri = connectionResult.ConnectionInfo.OwnerUri + }; + + await MetadataService.GetServerContextualization(getServerContextualizationParams, mockGetServerContextualizationRequestContext.Object); + + Assert.IsTrue(actualGetServerContextualizationResponse.Context.Contains(firstCreateTableScript)); + Assert.IsTrue(actualGetServerContextualizationResponse.Context.Contains(secondCreateTableScript)); + + DeleteServerContextualizationTempFile(sqlConn.DataSource); + + mockGetServerContextualizationRequestContext.VerifyAll(); + } + + private void DeleteServerContextualizationTempFile(string serverName) + { + var bytes = Encoding.UTF8.GetBytes(serverName); + var encodedServerName = Convert.ToBase64String(bytes); + var tempFileName = $"{encodedServerName}.tmp"; + + var tempFilePath = Path.Combine(Path.GetTempPath(), tempFileName); + if (File.Exists(tempFilePath)) + { + File.Delete(tempFilePath); + } + } + [Test] public async Task GetViewInfoReturnsValidResults() - { - var result = GetLiveAutoCompleteTestObjects(); + { + var result = GetLiveAutoCompleteTestObjects(); var requestContext = new Mock>(); requestContext.Setup(x => x.SendResult(It.IsAny())).Returns(Task.FromResult(new object())); @@ -166,7 +253,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata AS RETURN SELECT 1 AS AccessResult GO"; - + List expectedMetadataList = new List { new ObjectMetadata @@ -254,7 +341,7 @@ namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.Metadata return result.Metadata == null; } - if(expectedMetadataList.Count != result.Metadata.Length) + if (expectedMetadataList.Count != result.Metadata.Length) { return false; }