diff --git a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs index 8fbf2711..8f11ca05 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs @@ -32,6 +32,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlAssessment; using Microsoft.SqlTools.ServiceLayer.SqlContext; using Microsoft.SqlTools.ServiceLayer.Workspace; using Microsoft.SqlTools.ServiceLayer.NotebookConvert; +using Microsoft.SqlTools.ServiceLayer.ModelManagement; namespace Microsoft.SqlTools.ServiceLayer { @@ -137,7 +138,10 @@ namespace Microsoft.SqlTools.ServiceLayer ExternalLanguageService.Instance.InitializeService(serviceHost); serviceProvider.RegisterSingleService(ExternalLanguageService.Instance); - + + ModelManagementService.Instance.InitializeService(serviceHost); + serviceProvider.RegisterSingleService(ModelManagementService.Instance); + SqlAssessmentService.Instance.InitializeService(serviceHost); serviceProvider.RegisterSingleService(SqlAssessmentService.Instance); diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ConfigureModelTableRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ConfigureModelTableRequestParams.cs new file mode 100644 index 00000000..940f36f6 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ConfigureModelTableRequestParams.cs @@ -0,0 +1,30 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class ConfigureModelTableRequestParams : ModelRequestBase + { + } + + /// + /// Response class for get model + /// + public class ConfigureModelTableResponseParams : ModelResponseBase + { + } + + /// + /// Request class to get models + /// + public class ConfigureModelTableRequest + { + public static readonly + RequestType Type = + RequestType.Create("models/configure"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DeleteModelRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DeleteModelRequestParams.cs new file mode 100644 index 00000000..6bad3974 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DeleteModelRequestParams.cs @@ -0,0 +1,34 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class DeleteModelRequestParams : ModelRequestBase + { + /// + /// Model id + /// + public int ModelId { get; set; } + } + + /// + /// Response class for delete model + /// + public class DeleteModelResponseParams : ModelResponseBase + { + } + + /// + /// Request class to delete a model + /// + public class DeleteModelRequest + { + public static readonly + RequestType Type = + RequestType.Create("models/delete"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DownloadModelRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DownloadModelRequestParams.cs new file mode 100644 index 00000000..e3e70673 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DownloadModelRequestParams.cs @@ -0,0 +1,38 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class DownloadModelRequestParams : ModelRequestBase + { + /// + /// Model id + /// + public int ModelId { get; set; } + } + + /// + /// Response class for import model + /// + public class DownloadModelResponseParams : ModelResponseBase + { + /// + /// Downloaded file path + /// + public string FilePath { get; set; } + } + + /// + /// Request class to delete a model + /// + public class DownloadModelRequest + { + public static readonly + RequestType Type = + RequestType.Create("models/download"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/GetModelsRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/GetModelsRequestParams.cs new file mode 100644 index 00000000..133cdd12 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/GetModelsRequestParams.cs @@ -0,0 +1,32 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class GetModelsRequestParams : ModelRequestBase + { + } + + /// + /// Response class for get model + /// + public class GetModelsResponseParams : ModelResponseBase + { + public List Models { get; set; } + } + + /// + /// Request class to get models + /// + public class GetModelsRequest + { + public static readonly + RequestType Type = + RequestType.Create("models/get"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ImportModelRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ImportModelRequestParams.cs new file mode 100644 index 00000000..d63889b1 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ImportModelRequestParams.cs @@ -0,0 +1,34 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class ImportModelRequestParams : ModelRequestBase + { + /// + /// Model metadata + /// + public ModelMetadata Model { get; set; } + } + + /// + /// Response class for import model + /// + public class ImportModelResponseParams : ModelResponseBase + { + } + + /// + /// Request class to import a model + /// + public class ImportModelRequest + { + public static readonly + RequestType Type = + RequestType.Create("models/import"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelMetadata.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelMetadata.cs new file mode 100644 index 00000000..2f6709e2 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelMetadata.cs @@ -0,0 +1,75 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement +{ + /// + /// Model metadata + /// + public class ModelMetadata + { + /// + /// Model id + /// + public int Id { get; set; } + + /// + /// Model content length + /// + public Int64 ContentLength { get; set; } + + /// + /// Model name + /// + public string ModelName { get; set; } + + /// + /// Model created date + /// + public string Created { get; set; } + + /// + /// Model deployment time + /// + public string DeploymentTime { get; set; } + + /// + /// Model version + /// + public string Version { get; set; } + + /// + /// Model description + /// + public string Description { get; set; } + + /// + /// Model file path + /// + public string FilePath { get; set; } + + /// + /// Model framework + /// + public string Framework { get; set; } + + /// + /// Model framework version + /// + public string FrameworkVersion { get; set; } + + /// + /// Model run id + /// + public string RunId { get; set; } + + /// + /// Model deploy by + /// + public string DeployedBy { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelRequestBase.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelRequestBase.cs new file mode 100644 index 00000000..95e90a22 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelRequestBase.cs @@ -0,0 +1,36 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Utility; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class ModelRequestBase + { + /// + /// Models database name + /// + public string DatabaseName { get; set; } + + /// + /// The schema for model table + /// + public string SchemaName { get; set; } + + /// + /// Models table name + /// + public string TableName { get; set; } + + /// + /// Connection uri + /// + public string OwnerUri { get; set; } + } + + public class ModelResponseBase + { + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/UpdateModelRequestParams .cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/UpdateModelRequestParams .cs new file mode 100644 index 00000000..3fcde413 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/UpdateModelRequestParams .cs @@ -0,0 +1,30 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class UpdateModelRequestParams : ImportModelRequestParams + { + } + + /// + /// Response class for import model + /// + public class UpdateModelResponseParams : ModelResponseBase + { + } + + /// + /// Request class to import a model + /// + public class UpdateModelRequest + { + public static readonly + RequestType Type = + RequestType.Create("models/update"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/VerifyModelTableRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/VerifyModelTableRequestParams.cs new file mode 100644 index 00000000..a1fa0894 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/VerifyModelTableRequestParams.cs @@ -0,0 +1,35 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol.Contracts; +using System.Collections.Generic; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts +{ + public class VerifyModelTableRequestParams : ModelRequestBase + { + } + + /// + /// Response class for verify model table + /// + public class VerifyModelTableResponseParams : ModelResponseBase + { + /// + /// Specified is model table is verified + /// + public bool Verified { get; set; } + } + + /// + /// Request class to verify model table + /// + public class VerifyModelTableRequest + { + public static readonly + RequestType Type = + RequestType.Create("models/verify"); + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelManagementService.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelManagementService.cs new file mode 100644 index 00000000..8cb04c79 --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelManagementService.cs @@ -0,0 +1,241 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.Hosting; +using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts; +using Microsoft.SqlTools.Utility; +using System; +using System.Collections.Generic; +using System.Data; +using System.Diagnostics; +using System.Threading.Tasks; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement +{ + public class ModelManagementService + { + private ModelOperations serviceOperations = new ModelOperations(); + private ConnectionService connectionService = null; + private static readonly Lazy instance = new Lazy(() => new ModelManagementService()); + + + /// + /// Gets the singleton instance object + /// + public static ModelManagementService Instance + { + get { return instance.Value; } + } + + /// + /// Internal for testing purposes only + /// + internal ConnectionService ConnectionServiceInstance + { + get + { + if (connectionService == null) + { + connectionService = ConnectionService.Instance; + } + return connectionService; + } + + set + { + connectionService = value; + } + } + + public ModelOperations ModelOperations + { + get + { + return serviceOperations; + } + set + { + serviceOperations = value; + } + } + + public void InitializeService(ServiceHost serviceHost) + { + serviceHost.SetRequestHandler(ImportModelRequest.Type, this.HandleModelImportRequest); + serviceHost.SetRequestHandler(ConfigureModelTableRequest.Type, this.HandleConfigureModelTableRequest); + serviceHost.SetRequestHandler(DeleteModelRequest.Type, this.HandleDeleteModelRequest); + serviceHost.SetRequestHandler(DownloadModelRequest.Type, this.HandleDownloadModelRequest); + serviceHost.SetRequestHandler(GetModelsRequest.Type, this.HandleGetModelsRequest); + serviceHost.SetRequestHandler(UpdateModelRequest.Type, this.HandleUpdateModelRequest); + serviceHost.SetRequestHandler(VerifyModelTableRequest.Type, this.HandleVerifyModelTableRequest); + } + + /// + /// Handles import model request + /// + /// Request parameters + /// Request Context + public async Task HandleModelImportRequest(ImportModelRequestParams parameters, RequestContext requestContext) + { + Logger.Write(TraceEventType.Verbose, "HandleModelImportRequest"); + ImportModelResponseParams response = new ImportModelResponseParams + { + }; + + await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) => + { + ModelOperations.ImportModel(dbConnection, parameters); + return response; + }); + } + + /// + /// Handles get models request + /// + /// Request parameters + /// Request Context + public async Task HandleGetModelsRequest(GetModelsRequestParams parameters, RequestContext requestContext) + { + Logger.Write(TraceEventType.Verbose, "HandleGetModelsRequest"); + GetModelsResponseParams response = new GetModelsResponseParams + { + }; + + await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) => + { + List models = ModelOperations.GetModels(dbConnection, parameters); + response.Models = models; + return response; + }); + } + + /// + /// Handles update model request + /// + /// Request parameters + /// Request Context + public async Task HandleUpdateModelRequest(UpdateModelRequestParams parameters, RequestContext requestContext) + { + Logger.Write(TraceEventType.Verbose, "HandleUpdateModelRequest"); + UpdateModelResponseParams response = new UpdateModelResponseParams + { + }; + + await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) => + { + ModelOperations.UpdateModel(dbConnection, parameters); + return response; + }); + } + + /// + /// Handles delete model request + /// + /// Request parameters + /// Request Context + public async Task HandleDeleteModelRequest(DeleteModelRequestParams parameters, RequestContext requestContext) + { + Logger.Write(TraceEventType.Verbose, "HandleDeleteModelRequest"); + DeleteModelResponseParams response = new DeleteModelResponseParams + { + }; + + await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) => + { + ModelOperations.DeleteModel(dbConnection, parameters); + return response; + }); + } + + /// + /// Handles download model request + /// + /// Request parameters + /// Request Context + public async Task HandleDownloadModelRequest(DownloadModelRequestParams parameters, RequestContext requestContext) + { + Logger.Write(TraceEventType.Verbose, "HandleDownloadModelRequest"); + DownloadModelResponseParams response = new DownloadModelResponseParams + { + }; + + await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) => + { + response.FilePath = ModelOperations.DownloadModel(dbConnection, parameters); + return response; + }); + } + + /// + /// Handles verify model table request + /// + /// Request parameters + /// Request Context + public async Task HandleVerifyModelTableRequest(VerifyModelTableRequestParams parameters, RequestContext requestContext) + { + Logger.Write(TraceEventType.Verbose, "HandleVerifyModelTableRequest"); + VerifyModelTableResponseParams response = new VerifyModelTableResponseParams + { + }; + + await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) => + { + response.Verified = ModelOperations.VerifyImportTable(dbConnection, parameters); + return response; + }); + } + + /// + /// Handles configure model table request + /// + /// Request parameters + /// Request Context + public async Task HandleConfigureModelTableRequest(ConfigureModelTableRequestParams parameters, RequestContext requestContext) + { + Logger.Write(TraceEventType.Verbose, "HandleConfigureModelTableRequest"); + ConfigureModelTableResponseParams response = new ConfigureModelTableResponseParams(); + + await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) => + { + ModelOperations.ConfigureImportTable(dbConnection, parameters); + return response; + }); + } + + private async Task HandleRequest( + T parameters, + TResponse response, + RequestContext requestContext, + Func operation) where T : ModelRequestBase where TResponse : ModelResponseBase + { + try + { + ConnectionInfo connInfo; + ConnectionServiceInstance.TryFindConnection( + parameters.OwnerUri, + out connInfo); + if (connInfo == null) + { + await requestContext.SendError(new Exception(SR.ConnectionServiceDbErrorDefaultNotConnected(parameters.OwnerUri))); + } + else + { + using (IDbConnection dbConnection = ConnectionService.OpenSqlConnection(connInfo)) + { + response = operation(dbConnection, parameters, response); + } + await requestContext.SendResult(response); + } + } + catch (Exception e) + { + // Exception related to run task will be captured here + await requestContext.SendError(e); + } + } + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelOperations.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelOperations.cs new file mode 100644 index 00000000..1813789a --- /dev/null +++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelOperations.cs @@ -0,0 +1,416 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.ServiceLayer.Management; +using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts; +using Microsoft.SqlTools.ServiceLayer.Utility; +using System; +using System.Collections.Generic; +using System.Data; +using System.IO; + +namespace Microsoft.SqlTools.ServiceLayer.ModelManagement +{ + public class ModelOperations + { + /// + /// Returns models from given table + /// + /// Db connection + /// model request + /// Models + public virtual List GetModels(IDbConnection connection, ModelRequestBase request) + { + List models = new List(); + using (IDbCommand command = connection.CreateCommand()) + { + command.CommandText = GetSelectModelsQuery(request.DatabaseName, request.TableName, request.SchemaName); + + using (IDataReader reader = command.ExecuteReader()) + { + while (reader.Read()) + { + models.Add(LoadModelMetadata(reader)); + } + } + } + + return models; + } + + /// + /// Downlaods model content into a temp file and returns the file path + /// + /// Db connection + /// model request + /// Model file path + public virtual string DownloadModel(IDbConnection connection, DownloadModelRequestParams request) + { + string fileName = Path.GetTempFileName(); + using (IDbCommand command = connection.CreateCommand()) + { + Dictionary parameters = new Dictionary(); + command.CommandText = GetSelectModelContentQuery(request.DatabaseName, request.TableName, request.SchemaName, request.ModelId, parameters); + + foreach (var item in parameters) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = item.Key; + parameter.Value = item.Value; + command.Parameters.Add(parameter); + } + using (IDataReader reader = command.ExecuteReader()) + { + while (reader.Read()) + { + File.WriteAllBytes(fileName, (byte[])reader[0]); + } + } + } + + return fileName; + } + + /// + /// Import model to given table + /// + /// Db connection + /// model request + public virtual void ImportModel(IDbConnection connection, ImportModelRequestParams request) + { + WithDbChange(connection, request, (request) => + { + Dictionary parameters = new Dictionary(); + + using (IDbCommand command = connection.CreateCommand()) + { + command.CommandText = GetInsertModelQuery(request.TableName, request.SchemaName, request.Model, parameters); + + foreach (var item in parameters) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = item.Key; + parameter.Value = item.Value; + command.Parameters.Add(parameter); + } + command.ExecuteNonQuery(); + return true; + } + }); + } + + /// + /// Updates model + /// + /// Db connection + /// model request + public virtual void UpdateModel(IDbConnection connection, UpdateModelRequestParams request) + { + WithDbChange(connection, request, (request) => + { + Dictionary parameters = new Dictionary(); + using (IDbCommand command = connection.CreateCommand()) + { + command.CommandText = GetUpdateModelQuery(request.TableName, request.SchemaName, request.Model, parameters); + + foreach (var item in parameters) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = item.Key; + parameter.Value = item.Value; + command.Parameters.Add(parameter); + } + command.ExecuteNonQuery(); + return true; + } + }); + } + + /// + /// Deletes a model from the given table + /// + /// Db connection + /// model request + public virtual void DeleteModel(IDbConnection connection, DeleteModelRequestParams request) + { + WithDbChange(connection, request, (request) => + { + Dictionary parameters = new Dictionary(); + using (IDbCommand command = connection.CreateCommand()) + { + command.CommandText = GetDeleteModelQuery(request.TableName, request.SchemaName, request.ModelId, parameters); + + foreach (var item in parameters) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = item.Key; + parameter.Value = item.Value; + command.Parameters.Add(parameter); + } + command.ExecuteNonQuery(); + return true; + } + }); + } + + /// + /// Configures model table + /// + /// Db connection + /// model request + public virtual void ConfigureImportTable(IDbConnection connection, ModelRequestBase request) + { + WithDbChange(connection, request, (request) => + { + Dictionary parameters = new Dictionary(); + using (IDbCommand command = connection.CreateCommand()) + { + command.CommandText = GetCreateModelTableQuery(request.TableName, request.SchemaName); + + foreach (var item in parameters) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = item.Key; + parameter.Value = item.Value; + command.Parameters.Add(parameter); + } + command.ExecuteNonQuery(); + return true; + } + }); + } + + /// + /// Verifies model table + /// + /// Db connection + /// model request + public virtual bool VerifyImportTable(IDbConnection connection, ModelRequestBase request) + { + int result = WithDbChange(connection, request, (request) => + { + Dictionary parameters = new Dictionary(); + using (IDbCommand command = connection.CreateCommand()) + { + command.CommandText = GetConfigTableVerificationQuery(request.DatabaseName, request.TableName, request.SchemaName); + + command.ExecuteNonQuery(); + using (IDataReader reader = command.ExecuteReader()) + { + while (reader.Read()) + { + return reader.GetInt32(0); + } + } + return 0; + } + }); + + return result == 1; + } + + private TResult WithDbChange(IDbConnection connection, T request, Func operation) where T : ModelRequestBase + { + string currentDb = connection.Database; + if (connection.Database != request.DatabaseName) + { + connection.ChangeDatabase(request.DatabaseName); + } + TResult result = operation(request); + + if (connection.Database != currentDb) + { + connection.ChangeDatabase(currentDb); + } + return result; + } + + private ModelMetadata LoadModelMetadata(IDataReader reader) + { + return new ModelMetadata + { + Id = reader.GetInt32(0), + ModelName = reader.GetString(1), + Description = reader.IsDBNull(2) ? string.Empty : reader.GetString(2), + Version = reader.IsDBNull(3) ? string.Empty : reader.GetString(3), + Created = reader.IsDBNull(4) ? string.Empty : reader.GetDateTime(4).ToString(), + Framework = reader.IsDBNull(5) ? string.Empty : reader.GetString(5), + FrameworkVersion = reader.IsDBNull(6) ? string.Empty : reader.GetString(6), + DeploymentTime = reader.IsDBNull(7) ? string.Empty : reader.GetDateTime(7).ToString(), + DeployedBy = reader.IsDBNull(8) ? string.Empty : reader.GetString(8), + RunId = reader.IsDBNull(9) ? string.Empty : reader.GetString(9), + ContentLength = reader.GetInt64(10), + }; + } + + private const string ModelSelectColumns = @" + SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, User_Name(deployed_by), run_id, + len(model)"; + + private static string GetThreePartsTableName(string dbName, string tableName, string schemaName) + { + return $"[{CUtils.EscapeStringCBracket(dbName)}].[{CUtils.EscapeStringCBracket(schemaName)}].[{CUtils.EscapeStringCBracket(tableName)}]"; + } + + private static string GetTwoPartsTableName(string tableName, string schemaName) + { + return $"[{CUtils.EscapeStringCBracket(schemaName)}].[{CUtils.EscapeStringCBracket(tableName)}]"; + } + + private static string GetSelectModelsQuery(string dbName, string tableName, string schemaName) + { + return $@" + {ModelSelectColumns} + FROM {GetThreePartsTableName(dbName, tableName, schemaName)} + WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml' + ORDER BY model_id"; + } + + private static string GetConfigTableVerificationQuery(string dbName, string tableName, string schemaName) + { + string twoPartsTableName = GetTwoPartsTableName(CUtils.EscapeStringSQuote(tableName), CUtils.EscapeStringSQuote(schemaName)); + return $@" + IF NOT EXISTS ( + SELECT name + FROM sys.databases + WHERE name = N'{CUtils.EscapeStringSQuote(dbName)}' + ) + BEGIN + SELECT 0 + END + ELSE + BEGIN + USE [{CUtils.EscapeStringCBracket(dbName)}] + IF EXISTS + ( SELECT t.name, s.name + FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id + WHERE t.name = '{CUtils.EscapeStringSQuote(tableName)}' + AND s.name = '{CUtils.EscapeStringSQuote(schemaName)}' + ) + BEGIN + IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_name') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_id') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_description') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_framework') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_framework_version') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_version') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_creation_time') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_deployment_time') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='deployed_by') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='run_id') + BEGIN + SELECT 1 + END + ELSE + BEGIN + SELECT 0 + END + END + ELSE + SELECT 1 + END"; + } + + + private static string GetCreateModelTableQuery(string tableName, string schemaName) + { + return $@" + IF NOT EXISTS + ( SELECT t.name, s.name + FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id + WHERE t.name = '{CUtils.EscapeStringSQuote(tableName)}' + AND s.name = '{CUtils.EscapeStringSQuote(schemaName)}' + ) + BEGIN + CREATE TABLE {GetTwoPartsTableName(tableName, schemaName)} ( + [model_id] [int] IDENTITY(1,1) NOT NULL, + [model_name] [varchar](256) NOT NULL, + [model_framework] [varchar](256) NULL, + [model_framework_version] [varchar](256) NULL, + [model] [varbinary](max) NOT NULL, + [model_version] [varchar](256) NULL, + [model_creation_time] [datetime2] NULL, + [model_deployment_time] [datetime2] NULL, + [deployed_by] [int] NULL, + [model_description] [varchar](256) NULL, + [run_id] [varchar](256) NULL, + CONSTRAINT [{CUtils.EscapeStringCBracket(tableName)}_models_pk] PRIMARY KEY CLUSTERED + ( + [model_id] ASC + )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] + ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] + ALTER TABLE {GetTwoPartsTableName(tableName, schemaName)} ADD CONSTRAINT [{CUtils.EscapeStringCBracket(tableName)}_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time] + END +"; + } + + private static string GetInsertModelQuery(string tableName, string schemaName, ModelMetadata model, Dictionary parameters) + { + string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName); + + return $@" + INSERT INTO {twoPartsTableName} + (model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id, deployed_by) + VALUES ( + {DatabaseUtils.AddStringParameterForInsert(model.ModelName ?? "")}, + {DatabaseUtils.AddByteArrayParameterForInsert("Content", model.FilePath ?? "", parameters)}, + {DatabaseUtils.AddStringParameterForInsert(model.Version ?? "")}, + {DatabaseUtils.AddStringParameterForInsert(model.Description ?? "")}, + {DatabaseUtils.AddStringParameterForInsert(model.Created)}, + {DatabaseUtils.AddStringParameterForInsert(model.Framework ?? "")}, + {DatabaseUtils.AddStringParameterForInsert(model.FrameworkVersion ?? "")}, + {DatabaseUtils.AddStringParameterForInsert(model.RunId ?? "")}, + USER_ID (Current_User) + ) +"; + } + + private static string GetUpdateModelQuery(string tableName, string schemaName, ModelMetadata model, Dictionary parameters) + { + string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName); + parameters.Add(ModelIdParameterName, model.Id); + + return $@" + UPDATE {twoPartsTableName} + SET + {DatabaseUtils.AddStringParameterForUpdate("model_name", model.ModelName ?? "")}, + {DatabaseUtils.AddStringParameterForUpdate("model_version", model.Version ?? "")}, + {DatabaseUtils.AddStringParameterForUpdate("model_description", model.Description ?? "")}, + {DatabaseUtils.AddStringParameterForUpdate("model_creation_time", model.Created)}, + {DatabaseUtils.AddStringParameterForUpdate("model_framework", model.Framework ?? "")}, + {DatabaseUtils.AddStringParameterForUpdate("model_framework_version", model.FrameworkVersion ?? "")}, + {DatabaseUtils.AddStringParameterForUpdate("run_id", model.RunId ?? "")} + WHERE model_id = @{ModelIdParameterName} + +"; + } + + private static string GetDeleteModelQuery(string tableName, string schemaName, int modelId, Dictionary parameters) + { + string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName); + parameters.Add(ModelIdParameterName, modelId); + + return $@" + DELETE FROM {twoPartsTableName} + WHERE model_id = @{ModelIdParameterName} +"; + } + + private static string GetSelectModelContentQuery(string dbName, string tableName, string schemaName, int modelId, Dictionary parameters) + { + string threePartsTableName = GetThreePartsTableName(dbName, tableName, schemaName); + parameters.Add(ModelIdParameterName, modelId); + + return $@" + SELECT model + FROM {threePartsTableName} + WHERE model_id = @{ModelIdParameterName} +"; + } + + private const string ModelIdParameterName = "ModelId"; + } +} diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs index fd1428b8..c45b408e 100644 --- a/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs +++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs @@ -3,7 +3,10 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. // +using Microsoft.SqlTools.ServiceLayer.Management; using System; +using System.Collections.Generic; +using System.IO; namespace Microsoft.SqlTools.ServiceLayer.Utility { @@ -22,5 +25,45 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility string.Compare(databaseName, CommonConstants.ModelDatabaseName, StringComparison.OrdinalIgnoreCase) == 0 || string.Compare(databaseName, CommonConstants.TempDbDatabaseName, StringComparison.OrdinalIgnoreCase) == 0); } + + public static string AddStringParameterForInsert(string paramValue) + { + string value = string.IsNullOrWhiteSpace(paramValue) ? paramValue : CUtils.EscapeStringSQuote(paramValue); + return $"'{value}'"; + } + + public static string AddStringParameterForUpdate(string columnName, string paramValue) + { + string value = string.IsNullOrWhiteSpace(paramValue) ? paramValue : CUtils.EscapeStringSQuote(paramValue); + return $"{columnName} = N'{value}'"; + } + + public static string AddByteArrayParameterForUpdate(string columnName, string paramName, string fileName, Dictionary parameters) + { + byte[] contentBytes; + using (var stream = new FileStream(fileName, FileMode.Open, FileAccess.Read)) + { + using (var reader = new BinaryReader(stream)) + { + contentBytes = reader.ReadBytes((int)stream.Length); + } + } + parameters.Add($"{paramName}", contentBytes); + return $"{columnName} = @{paramName}"; + } + + public static string AddByteArrayParameterForInsert(string paramName, string fileName, Dictionary parameters) + { + byte[] contentBytes; + using (var stream = new FileStream(fileName, FileMode.Open, FileAccess.Read)) + { + using (var reader = new BinaryReader(stream)) + { + contentBytes = reader.ReadBytes((int)stream.Length); + } + } + parameters.Add($"{paramName}", contentBytes); + return $"@{paramName}"; + } } } diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelManagementServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelManagementServiceTests.cs new file mode 100644 index 00000000..a435882c --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelManagementServiceTests.cs @@ -0,0 +1,326 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.SqlTools.Extensibility; +using Microsoft.SqlTools.Hosting.Protocol; +using Microsoft.SqlTools.ServiceLayer.Connection.Contracts; +using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; +using Microsoft.SqlTools.ServiceLayer.ModelManagement; +using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using Microsoft.SqlTools.ServiceLayer.UnitTests; +using Moq; +using System; +using System.Collections.Generic; +using System.Data; +using System.Threading.Tasks; +using NUnit.Framework; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ModelManagement +{ + public class ModelManagementServiceTests : ServiceTestBase + { + + [Test] + public async Task VerifyDeleteModelRequest() + { + DeleteModelRequestParams requestParams = new DeleteModelRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name", + ModelId = 1 + }; + Mock operations = new Mock(); + operations.Setup(x => x.DeleteModel(It.IsAny(), requestParams)); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyRequst( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleDeleteModelRequest(requestParams, requestContext); + return null; + }, + verify: (actual => + { + Assert.NotNull(actual); + })); + } + + [Test] + public async Task VerifyImportModelRequest() + { + ImportModelRequestParams requestParams = new ImportModelRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name", + Model = new ModelMetadata() + }; + Mock operations = new Mock(); + operations.Setup(x => x.ImportModel(It.IsAny(), requestParams)); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyRequst( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleModelImportRequest(requestParams, requestContext); + return null; + }, + verify: (actual => + { + Assert.NotNull(actual); + })); + } + + [Test] + public async Task VerifyUpdateModelRequest() + { + UpdateModelRequestParams requestParams = new UpdateModelRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name", + Model = new ModelMetadata() + }; + Mock operations = new Mock(); + operations.Setup(x => x.UpdateModel(It.IsAny(), requestParams)); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyRequst( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleUpdateModelRequest(requestParams, requestContext); + return null; + }, + verify: (actual => + { + Assert.NotNull(actual); + })); + } + + [Test] + public async Task VerifyDownloadModelRequest() + { + DownloadModelRequestParams requestParams = new DownloadModelRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name", + ModelId = 1 + }; + Mock operations = new Mock(); + operations.Setup(x => x.DownloadModel(It.IsAny(), requestParams)).Returns(() => "file path"); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyRequst( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleDownloadModelRequest(requestParams, requestContext); + return null; + }, + verify: (actual => + { + Assert.NotNull(actual); + Assert.AreEqual(actual.FilePath, "file path"); + })); + } + + [Test] + public async Task VerifyModelTableRequest() + { + VerifyModelTableRequestParams requestParams = new VerifyModelTableRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name" + }; + Mock operations = new Mock(); + operations.Setup(x => x.VerifyImportTable(It.IsAny(), requestParams)).Returns(() => true); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyRequst( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleVerifyModelTableRequest(requestParams, requestContext); + return null; + }, + verify: (actual => + { + Assert.NotNull(actual); + Assert.AreEqual(actual.Verified, true); + })); + } + + [Test] + public async Task VerifyConfigureModelTableRequest() + { + ConfigureModelTableRequestParams requestParams = new ConfigureModelTableRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name" + }; + Mock operations = new Mock(); + operations.Setup(x => x.ConfigureImportTable(It.IsAny(), requestParams)); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyRequst( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleConfigureModelTableRequest(requestParams, requestContext); + return null; + }, + verify: (actual => + { + Assert.NotNull(actual); + })); + } + + [Test] + public async Task VerifyGetModelRequest() + { + GetModelsRequestParams requestParams = new GetModelsRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name", + }; + Mock operations = new Mock(); + operations.Setup(x => x.GetModels(It.IsAny(), requestParams)).Returns(() => new List { new ModelMetadata() }); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyRequst( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleGetModelsRequest(requestParams, requestContext); + return null; + }, + verify: (actual => + { + Assert.NotNull(actual); + Assert.True(actual.Models.Count == 1); + })); + } + + [Test] + public async Task VerifyRequestFailedResponse() + { + DeleteModelRequestParams requestParams = new DeleteModelRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name", + ModelId = 1 + }; + Mock operations = new Mock(); + operations.Setup(x => x.DeleteModel(It.IsAny(), requestParams)).Throws(new ApplicationException("error")); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyError( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = connectionUrl; + await service.HandleDeleteModelRequest(requestParams, requestContext); + return null; + }); + } + + [Test] + public async Task VerifyInvalidConnectionResponse() + { + DeleteModelRequestParams requestParams = new DeleteModelRequestParams + { + DatabaseName = "db name", + SchemaName = "dbo", + TableName = "table name", + ModelId = 1 + }; + Mock operations = new Mock(); + operations.Setup(x => x.DeleteModel(It.IsAny(), requestParams)).Throws(new ApplicationException("error")); + ModelManagementService service = new ModelManagementService() + { + ModelOperations = operations.Object + }; + + await VerifyError( + test: async (requestContext, connectionUrl) => + { + requestParams.OwnerUri = "Invalid connection uri"; + await service.HandleDeleteModelRequest(requestParams, requestContext); + return null; + }); + } + + public async Task VerifyRequst(Func, string, Task> test, Action verify) + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); + await RunAndVerify( + test: (requestContext) => test(requestContext, queryTempFile.FilePath), + verify: verify); + + ModelManagementService.Instance.ConnectionServiceInstance.Disconnect(new DisconnectParams + { + OwnerUri = queryTempFile.FilePath, + Type = ServiceLayer.Connection.ConnectionType.Default + }); + } + } + + public async Task VerifyError(Func, string, Task> test) + { + using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile()) + { + var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath); + await RunAndVerifyError( + test: (requestContext) => test(requestContext, queryTempFile.FilePath)); + + ModelManagementService.Instance.ConnectionServiceInstance.Disconnect(new DisconnectParams + { + OwnerUri = queryTempFile.FilePath, + Type = ServiceLayer.Connection.ConnectionType.Default + }); + } + } + + protected override RegisteredServiceProvider CreateServiceProviderWithMinServices() + { + return base.CreateProvider(); + } + } +} diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelOperationsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelOperationsTests.cs new file mode 100644 index 00000000..e51af553 --- /dev/null +++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelOperationsTests.cs @@ -0,0 +1,350 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using Microsoft.Data.SqlClient; +using Microsoft.SqlTools.ServiceLayer.Connection; +using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility; +using Microsoft.SqlTools.ServiceLayer.ModelManagement; +using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts; +using Microsoft.SqlTools.ServiceLayer.Test.Common; +using NUnit.Framework; +using System; +using System.Data; +using System.IO; +using System.Linq; + +namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ModelManagement +{ + public class ModelOperationsTests + { + [Test] + public void VerifyImportTableShouldReturnTrueGivenNoTable() + { + bool expected = true; + bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + ModelRequestBase request = new ModelRequestBase + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName + }; + return modelOperations.VerifyImportTable(dbConnection, request); + }); + + Assert.AreEqual(expected, actual); + } + + [Test] + public void ImportModelShouldImportSuccessfullyGivenValidModel() + { + string modelFilePath = Path.GetTempFileName(); + File.WriteAllText(modelFilePath, "Test model"); + ModelMetadata expected = new ModelMetadata + { + FilePath = modelFilePath, + Created = DateTime.Now.ToString(), + Description = "model description", + ModelName = "model name", + RunId = "run id", + Version = "1.2", + Framework = "ONNX", + FrameworkVersion = "1.3", + }; + ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + + ImportModelRequestParams request = new ImportModelRequestParams + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName, + Model = expected + }; + modelOperations.ConfigureImportTable(dbConnection, request); + modelOperations.ImportModel(dbConnection, request); + var models = modelOperations.GetModels(dbConnection, request); + return models.FirstOrDefault(x => x.ModelName == expected.ModelName); + }); + + Assert.IsNotNull(actual); + Assert.AreEqual(expected.RunId, actual.RunId); + Assert.AreEqual(expected.Description, actual.Description); + Assert.AreEqual(expected.Framework, actual.Framework); + Assert.AreEqual(expected.Version, actual.Version); + Assert.AreEqual(expected.FrameworkVersion, actual.FrameworkVersion); + Assert.IsFalse(string.IsNullOrWhiteSpace(actual.DeploymentTime)); + Assert.IsFalse(string.IsNullOrWhiteSpace(actual.DeployedBy)); + } + + [Test] + public void DownloadModelShouldDownloadSuccessfullyGivenValidModel() + { + string expected = "Test model"; + + string modelFilePath = Path.GetTempFileName(); + File.WriteAllText(modelFilePath, expected); + ModelMetadata model = new ModelMetadata + { + FilePath = modelFilePath, + Created = DateTime.Now.ToString(), + Description = "model description", + ModelName = "model name", + RunId = "run id", + Version = "1.2", + Framework = "ONNX", + FrameworkVersion = "1.3", + }; + string actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + ImportModelRequestParams request = new ImportModelRequestParams + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName, + Model = model + }; + modelOperations.ConfigureImportTable(dbConnection, request); + modelOperations.ImportModel(dbConnection, request); + var models = modelOperations.GetModels(dbConnection, request); + var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName); + Assert.IsNotNull(importedModel); + DownloadModelRequestParams downloadRequest = new DownloadModelRequestParams + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName, + ModelId = importedModel.Id + }; + string downloadedFile = modelOperations.DownloadModel(dbConnection, downloadRequest); + return File.ReadAllText(downloadedFile); + + }); + + Assert.IsNotNull(actual); + Assert.AreEqual(expected, actual); + } + + [Test] + public void UpdateModelShouldUpdateSuccessfullyGivenValidModel() + { + string modelFilePath = Path.GetTempFileName(); + File.WriteAllText(modelFilePath, "Test"); + ModelMetadata model = new ModelMetadata + { + FilePath = modelFilePath, + Created = DateTime.Now.ToString(), + Description = "model description", + ModelName = "model name", + RunId = "run id", + Version = "1.2", + Framework = "ONNX", + FrameworkVersion = "1.3", + }; + ModelMetadata expected = model; + + ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + ImportModelRequestParams request = new ImportModelRequestParams + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName, + Model = model + }; + modelOperations.ConfigureImportTable(dbConnection, request); + modelOperations.ImportModel(dbConnection, request); + var models = modelOperations.GetModels(dbConnection, request); + var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName); + Assert.IsNotNull(importedModel); + + UpdateModelRequestParams updateRequest = new UpdateModelRequestParams + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName, + Model = new ModelMetadata + { + Description = request.Model.Description + "updated", + ModelName = request.Model.ModelName + "updated", + Version = request.Model.Version + "updated", + Framework = request.Model.Framework + "updated", + RunId = request.Model.RunId + "updated", + FrameworkVersion = request.Model.FrameworkVersion + "updated", + Id = importedModel.Id + } + }; + modelOperations.UpdateModel(dbConnection, updateRequest); + models = modelOperations.GetModels(dbConnection, request); + var updatedModel = models.FirstOrDefault(x => x.ModelName == updateRequest.Model.ModelName); + Assert.IsNotNull(updatedModel); + return updatedModel; + }); + + Assert.IsNotNull(actual); + Assert.AreEqual(expected.Description + "updated", actual.Description); + Assert.AreEqual(expected.ModelName + "updated", actual.ModelName); + Assert.AreEqual(expected.Version + "updated", actual.Version); + Assert.AreEqual(expected.Framework + "updated", actual.Framework); + Assert.AreEqual(expected.FrameworkVersion + "updated", actual.FrameworkVersion); + Assert.AreEqual(expected.RunId + "updated", actual.RunId); + } + + [Test] + public void DeleteModelShouldDeleteSuccessfullyGivenValidModel() + { + string modelFilePath = Path.GetTempFileName(); + File.WriteAllText(modelFilePath, "Test"); + ModelMetadata model = new ModelMetadata + { + FilePath = modelFilePath, + Created = DateTime.Now.ToString(), + Description = "model description", + ModelName = "model name", + RunId = "run id", + Version = "1.2", + Framework = "ONNX", + FrameworkVersion = "1.3", + }; + ModelMetadata expected = model; + + ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + ImportModelRequestParams request = new ImportModelRequestParams + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName, + Model = model + }; + modelOperations.ConfigureImportTable(dbConnection, request); + modelOperations.ImportModel(dbConnection, request); + var models = modelOperations.GetModels(dbConnection, request); + var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName); + Assert.IsNotNull(importedModel); + + DeleteModelRequestParams deleteRequest = new DeleteModelRequestParams + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName, + ModelId = importedModel.Id + }; + modelOperations.DeleteModel(dbConnection, deleteRequest); + models = modelOperations.GetModels(dbConnection, request); + var updatedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName); + return updatedModel; + }); + + Assert.IsNull(actual); + } + + [Test] + public void VerifyImportTableShouldReturnTrueGivenTableCreatdByTheService() + { + bool expected = true; + bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + ModelRequestBase request = new ModelRequestBase + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName + }; + modelOperations.ConfigureImportTable(dbConnection, request); + return modelOperations.VerifyImportTable(dbConnection, request); + }); + + Assert.AreEqual(expected, actual); + } + + [Test] + public void VerifyImportTableShouldReturnTrueGivenTableNameThatNeedsEscaping() + { + bool expected = true; + bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + ModelRequestBase request = new ModelRequestBase + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName + }; + modelOperations.ConfigureImportTable(dbConnection, request); + return modelOperations.VerifyImportTable(dbConnection, request); + }, null, "models[]'"); + + Assert.AreEqual(expected, actual); + } + + [Test] + public void VerifyImportTableShouldReturnFalseGivenInvalidDbName() + { + Assert.Throws(typeof(SqlException), () => VerifyModelOperation((dbConnection, databaseName , tableName) => + { + ModelOperations modelOperations = new ModelOperations(); + ModelRequestBase request = new ModelRequestBase + { + DatabaseName = "invalidDb", + SchemaName = "dbo", + TableName = tableName + }; + modelOperations.ConfigureImportTable(dbConnection, request); + return modelOperations.VerifyImportTable(dbConnection, request); + })); + } + + [Test] + public void VerifyImportTableShouldReturnFalseGivenInvalidTable() + { + bool expected = false; + bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) => + { + dbConnection.ChangeDatabase(databaseName); + using (IDbCommand command = dbConnection.CreateCommand()) + { + command.CommandText = $"Create Table {tableName} (Id int, name varchar(10))"; + command.ExecuteNonQuery(); + } + dbConnection.ChangeDatabase("master"); + + ModelOperations modelOperations = new ModelOperations(); + ModelRequestBase request = new ModelRequestBase + { + DatabaseName = databaseName, + SchemaName = "dbo", + TableName = tableName + }; + return modelOperations.VerifyImportTable(dbConnection, request); + }); + + Assert.AreEqual(expected, actual); + } + + private T VerifyModelOperation(Func operation, string dbName = null, string tbName = null) + { + string databaseName = dbName ?? "testModels_" + new Random().Next(10000000, 99999999); + string tableName = tbName ?? "models"; + using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName)) + { + var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo(databaseName); + ConnectionInfo connInfo = liveConnection.ConnectionInfo; + IDbConnection dbConnection = ConnectionService.OpenSqlConnection(connInfo); + dbConnection.ChangeDatabase("master"); + + T result = operation(dbConnection, databaseName, tableName); + Assert.AreEqual(dbConnection.Database, "master"); + return result; + } + } + } +}