diff --git a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
index 8fbf2711..8f11ca05 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/HostLoader.cs
@@ -32,6 +32,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlAssessment;
using Microsoft.SqlTools.ServiceLayer.SqlContext;
using Microsoft.SqlTools.ServiceLayer.Workspace;
using Microsoft.SqlTools.ServiceLayer.NotebookConvert;
+using Microsoft.SqlTools.ServiceLayer.ModelManagement;
namespace Microsoft.SqlTools.ServiceLayer
{
@@ -137,7 +138,10 @@ namespace Microsoft.SqlTools.ServiceLayer
ExternalLanguageService.Instance.InitializeService(serviceHost);
serviceProvider.RegisterSingleService(ExternalLanguageService.Instance);
-
+
+ ModelManagementService.Instance.InitializeService(serviceHost);
+ serviceProvider.RegisterSingleService(ModelManagementService.Instance);
+
SqlAssessmentService.Instance.InitializeService(serviceHost);
serviceProvider.RegisterSingleService(SqlAssessmentService.Instance);
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ConfigureModelTableRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ConfigureModelTableRequestParams.cs
new file mode 100644
index 00000000..940f36f6
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ConfigureModelTableRequestParams.cs
@@ -0,0 +1,30 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class ConfigureModelTableRequestParams : ModelRequestBase
+ {
+ }
+
+ ///
+ /// Response class for get model
+ ///
+ public class ConfigureModelTableResponseParams : ModelResponseBase
+ {
+ }
+
+ ///
+ /// Request class to get models
+ ///
+ public class ConfigureModelTableRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("models/configure");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DeleteModelRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DeleteModelRequestParams.cs
new file mode 100644
index 00000000..6bad3974
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DeleteModelRequestParams.cs
@@ -0,0 +1,34 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class DeleteModelRequestParams : ModelRequestBase
+ {
+ ///
+ /// Model id
+ ///
+ public int ModelId { get; set; }
+ }
+
+ ///
+ /// Response class for delete model
+ ///
+ public class DeleteModelResponseParams : ModelResponseBase
+ {
+ }
+
+ ///
+ /// Request class to delete a model
+ ///
+ public class DeleteModelRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("models/delete");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DownloadModelRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DownloadModelRequestParams.cs
new file mode 100644
index 00000000..e3e70673
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/DownloadModelRequestParams.cs
@@ -0,0 +1,38 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class DownloadModelRequestParams : ModelRequestBase
+ {
+ ///
+ /// Model id
+ ///
+ public int ModelId { get; set; }
+ }
+
+ ///
+ /// Response class for import model
+ ///
+ public class DownloadModelResponseParams : ModelResponseBase
+ {
+ ///
+ /// Downloaded file path
+ ///
+ public string FilePath { get; set; }
+ }
+
+ ///
+ /// Request class to delete a model
+ ///
+ public class DownloadModelRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("models/download");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/GetModelsRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/GetModelsRequestParams.cs
new file mode 100644
index 00000000..133cdd12
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/GetModelsRequestParams.cs
@@ -0,0 +1,32 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+using System.Collections.Generic;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class GetModelsRequestParams : ModelRequestBase
+ {
+ }
+
+ ///
+ /// Response class for get model
+ ///
+ public class GetModelsResponseParams : ModelResponseBase
+ {
+ public List Models { get; set; }
+ }
+
+ ///
+ /// Request class to get models
+ ///
+ public class GetModelsRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("models/get");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ImportModelRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ImportModelRequestParams.cs
new file mode 100644
index 00000000..d63889b1
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ImportModelRequestParams.cs
@@ -0,0 +1,34 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class ImportModelRequestParams : ModelRequestBase
+ {
+ ///
+ /// Model metadata
+ ///
+ public ModelMetadata Model { get; set; }
+ }
+
+ ///
+ /// Response class for import model
+ ///
+ public class ImportModelResponseParams : ModelResponseBase
+ {
+ }
+
+ ///
+ /// Request class to import a model
+ ///
+ public class ImportModelRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("models/import");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelMetadata.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelMetadata.cs
new file mode 100644
index 00000000..2f6709e2
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelMetadata.cs
@@ -0,0 +1,75 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using System;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement
+{
+ ///
+ /// Model metadata
+ ///
+ public class ModelMetadata
+ {
+ ///
+ /// Model id
+ ///
+ public int Id { get; set; }
+
+ ///
+ /// Model content length
+ ///
+ public Int64 ContentLength { get; set; }
+
+ ///
+ /// Model name
+ ///
+ public string ModelName { get; set; }
+
+ ///
+ /// Model created date
+ ///
+ public string Created { get; set; }
+
+ ///
+ /// Model deployment time
+ ///
+ public string DeploymentTime { get; set; }
+
+ ///
+ /// Model version
+ ///
+ public string Version { get; set; }
+
+ ///
+ /// Model description
+ ///
+ public string Description { get; set; }
+
+ ///
+ /// Model file path
+ ///
+ public string FilePath { get; set; }
+
+ ///
+ /// Model framework
+ ///
+ public string Framework { get; set; }
+
+ ///
+ /// Model framework version
+ ///
+ public string FrameworkVersion { get; set; }
+
+ ///
+ /// Model run id
+ ///
+ public string RunId { get; set; }
+
+ ///
+ /// Model deploy by
+ ///
+ public string DeployedBy { get; set; }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelRequestBase.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelRequestBase.cs
new file mode 100644
index 00000000..95e90a22
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/ModelRequestBase.cs
@@ -0,0 +1,36 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.ServiceLayer.Utility;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class ModelRequestBase
+ {
+ ///
+ /// Models database name
+ ///
+ public string DatabaseName { get; set; }
+
+ ///
+ /// The schema for model table
+ ///
+ public string SchemaName { get; set; }
+
+ ///
+ /// Models table name
+ ///
+ public string TableName { get; set; }
+
+ ///
+ /// Connection uri
+ ///
+ public string OwnerUri { get; set; }
+ }
+
+ public class ModelResponseBase
+ {
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/UpdateModelRequestParams .cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/UpdateModelRequestParams .cs
new file mode 100644
index 00000000..3fcde413
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/UpdateModelRequestParams .cs
@@ -0,0 +1,30 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class UpdateModelRequestParams : ImportModelRequestParams
+ {
+ }
+
+ ///
+ /// Response class for import model
+ ///
+ public class UpdateModelResponseParams : ModelResponseBase
+ {
+ }
+
+ ///
+ /// Request class to import a model
+ ///
+ public class UpdateModelRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("models/update");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/VerifyModelTableRequestParams.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/VerifyModelTableRequestParams.cs
new file mode 100644
index 00000000..a1fa0894
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/Contracts/VerifyModelTableRequestParams.cs
@@ -0,0 +1,35 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol.Contracts;
+using System.Collections.Generic;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
+{
+ public class VerifyModelTableRequestParams : ModelRequestBase
+ {
+ }
+
+ ///
+ /// Response class for verify model table
+ ///
+ public class VerifyModelTableResponseParams : ModelResponseBase
+ {
+ ///
+ /// Specified is model table is verified
+ ///
+ public bool Verified { get; set; }
+ }
+
+ ///
+ /// Request class to verify model table
+ ///
+ public class VerifyModelTableRequest
+ {
+ public static readonly
+ RequestType Type =
+ RequestType.Create("models/verify");
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelManagementService.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelManagementService.cs
new file mode 100644
index 00000000..8cb04c79
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelManagementService.cs
@@ -0,0 +1,241 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.Hosting;
+using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
+using Microsoft.SqlTools.Utility;
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Diagnostics;
+using System.Threading.Tasks;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement
+{
+ public class ModelManagementService
+ {
+ private ModelOperations serviceOperations = new ModelOperations();
+ private ConnectionService connectionService = null;
+ private static readonly Lazy instance = new Lazy(() => new ModelManagementService());
+
+
+ ///
+ /// Gets the singleton instance object
+ ///
+ public static ModelManagementService Instance
+ {
+ get { return instance.Value; }
+ }
+
+ ///
+ /// Internal for testing purposes only
+ ///
+ internal ConnectionService ConnectionServiceInstance
+ {
+ get
+ {
+ if (connectionService == null)
+ {
+ connectionService = ConnectionService.Instance;
+ }
+ return connectionService;
+ }
+
+ set
+ {
+ connectionService = value;
+ }
+ }
+
+ public ModelOperations ModelOperations
+ {
+ get
+ {
+ return serviceOperations;
+ }
+ set
+ {
+ serviceOperations = value;
+ }
+ }
+
+ public void InitializeService(ServiceHost serviceHost)
+ {
+ serviceHost.SetRequestHandler(ImportModelRequest.Type, this.HandleModelImportRequest);
+ serviceHost.SetRequestHandler(ConfigureModelTableRequest.Type, this.HandleConfigureModelTableRequest);
+ serviceHost.SetRequestHandler(DeleteModelRequest.Type, this.HandleDeleteModelRequest);
+ serviceHost.SetRequestHandler(DownloadModelRequest.Type, this.HandleDownloadModelRequest);
+ serviceHost.SetRequestHandler(GetModelsRequest.Type, this.HandleGetModelsRequest);
+ serviceHost.SetRequestHandler(UpdateModelRequest.Type, this.HandleUpdateModelRequest);
+ serviceHost.SetRequestHandler(VerifyModelTableRequest.Type, this.HandleVerifyModelTableRequest);
+ }
+
+ ///
+ /// Handles import model request
+ ///
+ /// Request parameters
+ /// Request Context
+ public async Task HandleModelImportRequest(ImportModelRequestParams parameters, RequestContext requestContext)
+ {
+ Logger.Write(TraceEventType.Verbose, "HandleModelImportRequest");
+ ImportModelResponseParams response = new ImportModelResponseParams
+ {
+ };
+
+ await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
+ {
+ ModelOperations.ImportModel(dbConnection, parameters);
+ return response;
+ });
+ }
+
+ ///
+ /// Handles get models request
+ ///
+ /// Request parameters
+ /// Request Context
+ public async Task HandleGetModelsRequest(GetModelsRequestParams parameters, RequestContext requestContext)
+ {
+ Logger.Write(TraceEventType.Verbose, "HandleGetModelsRequest");
+ GetModelsResponseParams response = new GetModelsResponseParams
+ {
+ };
+
+ await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
+ {
+ List models = ModelOperations.GetModels(dbConnection, parameters);
+ response.Models = models;
+ return response;
+ });
+ }
+
+ ///
+ /// Handles update model request
+ ///
+ /// Request parameters
+ /// Request Context
+ public async Task HandleUpdateModelRequest(UpdateModelRequestParams parameters, RequestContext requestContext)
+ {
+ Logger.Write(TraceEventType.Verbose, "HandleUpdateModelRequest");
+ UpdateModelResponseParams response = new UpdateModelResponseParams
+ {
+ };
+
+ await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
+ {
+ ModelOperations.UpdateModel(dbConnection, parameters);
+ return response;
+ });
+ }
+
+ ///
+ /// Handles delete model request
+ ///
+ /// Request parameters
+ /// Request Context
+ public async Task HandleDeleteModelRequest(DeleteModelRequestParams parameters, RequestContext requestContext)
+ {
+ Logger.Write(TraceEventType.Verbose, "HandleDeleteModelRequest");
+ DeleteModelResponseParams response = new DeleteModelResponseParams
+ {
+ };
+
+ await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
+ {
+ ModelOperations.DeleteModel(dbConnection, parameters);
+ return response;
+ });
+ }
+
+ ///
+ /// Handles download model request
+ ///
+ /// Request parameters
+ /// Request Context
+ public async Task HandleDownloadModelRequest(DownloadModelRequestParams parameters, RequestContext requestContext)
+ {
+ Logger.Write(TraceEventType.Verbose, "HandleDownloadModelRequest");
+ DownloadModelResponseParams response = new DownloadModelResponseParams
+ {
+ };
+
+ await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
+ {
+ response.FilePath = ModelOperations.DownloadModel(dbConnection, parameters);
+ return response;
+ });
+ }
+
+ ///
+ /// Handles verify model table request
+ ///
+ /// Request parameters
+ /// Request Context
+ public async Task HandleVerifyModelTableRequest(VerifyModelTableRequestParams parameters, RequestContext requestContext)
+ {
+ Logger.Write(TraceEventType.Verbose, "HandleVerifyModelTableRequest");
+ VerifyModelTableResponseParams response = new VerifyModelTableResponseParams
+ {
+ };
+
+ await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
+ {
+ response.Verified = ModelOperations.VerifyImportTable(dbConnection, parameters);
+ return response;
+ });
+ }
+
+ ///
+ /// Handles configure model table request
+ ///
+ /// Request parameters
+ /// Request Context
+ public async Task HandleConfigureModelTableRequest(ConfigureModelTableRequestParams parameters, RequestContext requestContext)
+ {
+ Logger.Write(TraceEventType.Verbose, "HandleConfigureModelTableRequest");
+ ConfigureModelTableResponseParams response = new ConfigureModelTableResponseParams();
+
+ await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
+ {
+ ModelOperations.ConfigureImportTable(dbConnection, parameters);
+ return response;
+ });
+ }
+
+ private async Task HandleRequest(
+ T parameters,
+ TResponse response,
+ RequestContext requestContext,
+ Func operation) where T : ModelRequestBase where TResponse : ModelResponseBase
+ {
+ try
+ {
+ ConnectionInfo connInfo;
+ ConnectionServiceInstance.TryFindConnection(
+ parameters.OwnerUri,
+ out connInfo);
+ if (connInfo == null)
+ {
+ await requestContext.SendError(new Exception(SR.ConnectionServiceDbErrorDefaultNotConnected(parameters.OwnerUri)));
+ }
+ else
+ {
+ using (IDbConnection dbConnection = ConnectionService.OpenSqlConnection(connInfo))
+ {
+ response = operation(dbConnection, parameters, response);
+ }
+ await requestContext.SendResult(response);
+ }
+ }
+ catch (Exception e)
+ {
+ // Exception related to run task will be captured here
+ await requestContext.SendError(e);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelOperations.cs b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelOperations.cs
new file mode 100644
index 00000000..1813789a
--- /dev/null
+++ b/src/Microsoft.SqlTools.ServiceLayer/ModelManagement/ModelOperations.cs
@@ -0,0 +1,416 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.ServiceLayer.Management;
+using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Utility;
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.IO;
+
+namespace Microsoft.SqlTools.ServiceLayer.ModelManagement
+{
+ public class ModelOperations
+ {
+ ///
+ /// Returns models from given table
+ ///
+ /// Db connection
+ /// model request
+ /// Models
+ public virtual List GetModels(IDbConnection connection, ModelRequestBase request)
+ {
+ List models = new List();
+ using (IDbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = GetSelectModelsQuery(request.DatabaseName, request.TableName, request.SchemaName);
+
+ using (IDataReader reader = command.ExecuteReader())
+ {
+ while (reader.Read())
+ {
+ models.Add(LoadModelMetadata(reader));
+ }
+ }
+ }
+
+ return models;
+ }
+
+ ///
+ /// Downlaods model content into a temp file and returns the file path
+ ///
+ /// Db connection
+ /// model request
+ /// Model file path
+ public virtual string DownloadModel(IDbConnection connection, DownloadModelRequestParams request)
+ {
+ string fileName = Path.GetTempFileName();
+ using (IDbCommand command = connection.CreateCommand())
+ {
+ Dictionary parameters = new Dictionary();
+ command.CommandText = GetSelectModelContentQuery(request.DatabaseName, request.TableName, request.SchemaName, request.ModelId, parameters);
+
+ foreach (var item in parameters)
+ {
+ var parameter = command.CreateParameter();
+ parameter.ParameterName = item.Key;
+ parameter.Value = item.Value;
+ command.Parameters.Add(parameter);
+ }
+ using (IDataReader reader = command.ExecuteReader())
+ {
+ while (reader.Read())
+ {
+ File.WriteAllBytes(fileName, (byte[])reader[0]);
+ }
+ }
+ }
+
+ return fileName;
+ }
+
+ ///
+ /// Import model to given table
+ ///
+ /// Db connection
+ /// model request
+ public virtual void ImportModel(IDbConnection connection, ImportModelRequestParams request)
+ {
+ WithDbChange(connection, request, (request) =>
+ {
+ Dictionary parameters = new Dictionary();
+
+ using (IDbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = GetInsertModelQuery(request.TableName, request.SchemaName, request.Model, parameters);
+
+ foreach (var item in parameters)
+ {
+ var parameter = command.CreateParameter();
+ parameter.ParameterName = item.Key;
+ parameter.Value = item.Value;
+ command.Parameters.Add(parameter);
+ }
+ command.ExecuteNonQuery();
+ return true;
+ }
+ });
+ }
+
+ ///
+ /// Updates model
+ ///
+ /// Db connection
+ /// model request
+ public virtual void UpdateModel(IDbConnection connection, UpdateModelRequestParams request)
+ {
+ WithDbChange(connection, request, (request) =>
+ {
+ Dictionary parameters = new Dictionary();
+ using (IDbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = GetUpdateModelQuery(request.TableName, request.SchemaName, request.Model, parameters);
+
+ foreach (var item in parameters)
+ {
+ var parameter = command.CreateParameter();
+ parameter.ParameterName = item.Key;
+ parameter.Value = item.Value;
+ command.Parameters.Add(parameter);
+ }
+ command.ExecuteNonQuery();
+ return true;
+ }
+ });
+ }
+
+ ///
+ /// Deletes a model from the given table
+ ///
+ /// Db connection
+ /// model request
+ public virtual void DeleteModel(IDbConnection connection, DeleteModelRequestParams request)
+ {
+ WithDbChange(connection, request, (request) =>
+ {
+ Dictionary parameters = new Dictionary();
+ using (IDbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = GetDeleteModelQuery(request.TableName, request.SchemaName, request.ModelId, parameters);
+
+ foreach (var item in parameters)
+ {
+ var parameter = command.CreateParameter();
+ parameter.ParameterName = item.Key;
+ parameter.Value = item.Value;
+ command.Parameters.Add(parameter);
+ }
+ command.ExecuteNonQuery();
+ return true;
+ }
+ });
+ }
+
+ ///
+ /// Configures model table
+ ///
+ /// Db connection
+ /// model request
+ public virtual void ConfigureImportTable(IDbConnection connection, ModelRequestBase request)
+ {
+ WithDbChange(connection, request, (request) =>
+ {
+ Dictionary parameters = new Dictionary();
+ using (IDbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = GetCreateModelTableQuery(request.TableName, request.SchemaName);
+
+ foreach (var item in parameters)
+ {
+ var parameter = command.CreateParameter();
+ parameter.ParameterName = item.Key;
+ parameter.Value = item.Value;
+ command.Parameters.Add(parameter);
+ }
+ command.ExecuteNonQuery();
+ return true;
+ }
+ });
+ }
+
+ ///
+ /// Verifies model table
+ ///
+ /// Db connection
+ /// model request
+ public virtual bool VerifyImportTable(IDbConnection connection, ModelRequestBase request)
+ {
+ int result = WithDbChange(connection, request, (request) =>
+ {
+ Dictionary parameters = new Dictionary();
+ using (IDbCommand command = connection.CreateCommand())
+ {
+ command.CommandText = GetConfigTableVerificationQuery(request.DatabaseName, request.TableName, request.SchemaName);
+
+ command.ExecuteNonQuery();
+ using (IDataReader reader = command.ExecuteReader())
+ {
+ while (reader.Read())
+ {
+ return reader.GetInt32(0);
+ }
+ }
+ return 0;
+ }
+ });
+
+ return result == 1;
+ }
+
+ private TResult WithDbChange(IDbConnection connection, T request, Func operation) where T : ModelRequestBase
+ {
+ string currentDb = connection.Database;
+ if (connection.Database != request.DatabaseName)
+ {
+ connection.ChangeDatabase(request.DatabaseName);
+ }
+ TResult result = operation(request);
+
+ if (connection.Database != currentDb)
+ {
+ connection.ChangeDatabase(currentDb);
+ }
+ return result;
+ }
+
+ private ModelMetadata LoadModelMetadata(IDataReader reader)
+ {
+ return new ModelMetadata
+ {
+ Id = reader.GetInt32(0),
+ ModelName = reader.GetString(1),
+ Description = reader.IsDBNull(2) ? string.Empty : reader.GetString(2),
+ Version = reader.IsDBNull(3) ? string.Empty : reader.GetString(3),
+ Created = reader.IsDBNull(4) ? string.Empty : reader.GetDateTime(4).ToString(),
+ Framework = reader.IsDBNull(5) ? string.Empty : reader.GetString(5),
+ FrameworkVersion = reader.IsDBNull(6) ? string.Empty : reader.GetString(6),
+ DeploymentTime = reader.IsDBNull(7) ? string.Empty : reader.GetDateTime(7).ToString(),
+ DeployedBy = reader.IsDBNull(8) ? string.Empty : reader.GetString(8),
+ RunId = reader.IsDBNull(9) ? string.Empty : reader.GetString(9),
+ ContentLength = reader.GetInt64(10),
+ };
+ }
+
+ private const string ModelSelectColumns = @"
+ SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, User_Name(deployed_by), run_id,
+ len(model)";
+
+ private static string GetThreePartsTableName(string dbName, string tableName, string schemaName)
+ {
+ return $"[{CUtils.EscapeStringCBracket(dbName)}].[{CUtils.EscapeStringCBracket(schemaName)}].[{CUtils.EscapeStringCBracket(tableName)}]";
+ }
+
+ private static string GetTwoPartsTableName(string tableName, string schemaName)
+ {
+ return $"[{CUtils.EscapeStringCBracket(schemaName)}].[{CUtils.EscapeStringCBracket(tableName)}]";
+ }
+
+ private static string GetSelectModelsQuery(string dbName, string tableName, string schemaName)
+ {
+ return $@"
+ {ModelSelectColumns}
+ FROM {GetThreePartsTableName(dbName, tableName, schemaName)}
+ WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
+ ORDER BY model_id";
+ }
+
+ private static string GetConfigTableVerificationQuery(string dbName, string tableName, string schemaName)
+ {
+ string twoPartsTableName = GetTwoPartsTableName(CUtils.EscapeStringSQuote(tableName), CUtils.EscapeStringSQuote(schemaName));
+ return $@"
+ IF NOT EXISTS (
+ SELECT name
+ FROM sys.databases
+ WHERE name = N'{CUtils.EscapeStringSQuote(dbName)}'
+ )
+ BEGIN
+ SELECT 0
+ END
+ ELSE
+ BEGIN
+ USE [{CUtils.EscapeStringCBracket(dbName)}]
+ IF EXISTS
+ ( SELECT t.name, s.name
+ FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
+ WHERE t.name = '{CUtils.EscapeStringSQuote(tableName)}'
+ AND s.name = '{CUtils.EscapeStringSQuote(schemaName)}'
+ )
+ BEGIN
+ IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_name')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_id')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_description')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_framework')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_framework_version')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_version')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_creation_time')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_deployment_time')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='deployed_by')
+ AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='run_id')
+ BEGIN
+ SELECT 1
+ END
+ ELSE
+ BEGIN
+ SELECT 0
+ END
+ END
+ ELSE
+ SELECT 1
+ END";
+ }
+
+
+ private static string GetCreateModelTableQuery(string tableName, string schemaName)
+ {
+ return $@"
+ IF NOT EXISTS
+ ( SELECT t.name, s.name
+ FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
+ WHERE t.name = '{CUtils.EscapeStringSQuote(tableName)}'
+ AND s.name = '{CUtils.EscapeStringSQuote(schemaName)}'
+ )
+ BEGIN
+ CREATE TABLE {GetTwoPartsTableName(tableName, schemaName)} (
+ [model_id] [int] IDENTITY(1,1) NOT NULL,
+ [model_name] [varchar](256) NOT NULL,
+ [model_framework] [varchar](256) NULL,
+ [model_framework_version] [varchar](256) NULL,
+ [model] [varbinary](max) NOT NULL,
+ [model_version] [varchar](256) NULL,
+ [model_creation_time] [datetime2] NULL,
+ [model_deployment_time] [datetime2] NULL,
+ [deployed_by] [int] NULL,
+ [model_description] [varchar](256) NULL,
+ [run_id] [varchar](256) NULL,
+ CONSTRAINT [{CUtils.EscapeStringCBracket(tableName)}_models_pk] PRIMARY KEY CLUSTERED
+ (
+ [model_id] ASC
+ )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
+ ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
+ ALTER TABLE {GetTwoPartsTableName(tableName, schemaName)} ADD CONSTRAINT [{CUtils.EscapeStringCBracket(tableName)}_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
+ END
+";
+ }
+
+ private static string GetInsertModelQuery(string tableName, string schemaName, ModelMetadata model, Dictionary parameters)
+ {
+ string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName);
+
+ return $@"
+ INSERT INTO {twoPartsTableName}
+ (model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id, deployed_by)
+ VALUES (
+ {DatabaseUtils.AddStringParameterForInsert(model.ModelName ?? "")},
+ {DatabaseUtils.AddByteArrayParameterForInsert("Content", model.FilePath ?? "", parameters)},
+ {DatabaseUtils.AddStringParameterForInsert(model.Version ?? "")},
+ {DatabaseUtils.AddStringParameterForInsert(model.Description ?? "")},
+ {DatabaseUtils.AddStringParameterForInsert(model.Created)},
+ {DatabaseUtils.AddStringParameterForInsert(model.Framework ?? "")},
+ {DatabaseUtils.AddStringParameterForInsert(model.FrameworkVersion ?? "")},
+ {DatabaseUtils.AddStringParameterForInsert(model.RunId ?? "")},
+ USER_ID (Current_User)
+ )
+";
+ }
+
+ private static string GetUpdateModelQuery(string tableName, string schemaName, ModelMetadata model, Dictionary parameters)
+ {
+ string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName);
+ parameters.Add(ModelIdParameterName, model.Id);
+
+ return $@"
+ UPDATE {twoPartsTableName}
+ SET
+ {DatabaseUtils.AddStringParameterForUpdate("model_name", model.ModelName ?? "")},
+ {DatabaseUtils.AddStringParameterForUpdate("model_version", model.Version ?? "")},
+ {DatabaseUtils.AddStringParameterForUpdate("model_description", model.Description ?? "")},
+ {DatabaseUtils.AddStringParameterForUpdate("model_creation_time", model.Created)},
+ {DatabaseUtils.AddStringParameterForUpdate("model_framework", model.Framework ?? "")},
+ {DatabaseUtils.AddStringParameterForUpdate("model_framework_version", model.FrameworkVersion ?? "")},
+ {DatabaseUtils.AddStringParameterForUpdate("run_id", model.RunId ?? "")}
+ WHERE model_id = @{ModelIdParameterName}
+
+";
+ }
+
+ private static string GetDeleteModelQuery(string tableName, string schemaName, int modelId, Dictionary parameters)
+ {
+ string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName);
+ parameters.Add(ModelIdParameterName, modelId);
+
+ return $@"
+ DELETE FROM {twoPartsTableName}
+ WHERE model_id = @{ModelIdParameterName}
+";
+ }
+
+ private static string GetSelectModelContentQuery(string dbName, string tableName, string schemaName, int modelId, Dictionary parameters)
+ {
+ string threePartsTableName = GetThreePartsTableName(dbName, tableName, schemaName);
+ parameters.Add(ModelIdParameterName, modelId);
+
+ return $@"
+ SELECT model
+ FROM {threePartsTableName}
+ WHERE model_id = @{ModelIdParameterName}
+";
+ }
+
+ private const string ModelIdParameterName = "ModelId";
+ }
+}
diff --git a/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs b/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs
index fd1428b8..c45b408e 100644
--- a/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs
+++ b/src/Microsoft.SqlTools.ServiceLayer/Utility/DatabaseUtils.cs
@@ -3,7 +3,10 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
//
+using Microsoft.SqlTools.ServiceLayer.Management;
using System;
+using System.Collections.Generic;
+using System.IO;
namespace Microsoft.SqlTools.ServiceLayer.Utility
{
@@ -22,5 +25,45 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
string.Compare(databaseName, CommonConstants.ModelDatabaseName, StringComparison.OrdinalIgnoreCase) == 0 ||
string.Compare(databaseName, CommonConstants.TempDbDatabaseName, StringComparison.OrdinalIgnoreCase) == 0);
}
+
+ public static string AddStringParameterForInsert(string paramValue)
+ {
+ string value = string.IsNullOrWhiteSpace(paramValue) ? paramValue : CUtils.EscapeStringSQuote(paramValue);
+ return $"'{value}'";
+ }
+
+ public static string AddStringParameterForUpdate(string columnName, string paramValue)
+ {
+ string value = string.IsNullOrWhiteSpace(paramValue) ? paramValue : CUtils.EscapeStringSQuote(paramValue);
+ return $"{columnName} = N'{value}'";
+ }
+
+ public static string AddByteArrayParameterForUpdate(string columnName, string paramName, string fileName, Dictionary parameters)
+ {
+ byte[] contentBytes;
+ using (var stream = new FileStream(fileName, FileMode.Open, FileAccess.Read))
+ {
+ using (var reader = new BinaryReader(stream))
+ {
+ contentBytes = reader.ReadBytes((int)stream.Length);
+ }
+ }
+ parameters.Add($"{paramName}", contentBytes);
+ return $"{columnName} = @{paramName}";
+ }
+
+ public static string AddByteArrayParameterForInsert(string paramName, string fileName, Dictionary parameters)
+ {
+ byte[] contentBytes;
+ using (var stream = new FileStream(fileName, FileMode.Open, FileAccess.Read))
+ {
+ using (var reader = new BinaryReader(stream))
+ {
+ contentBytes = reader.ReadBytes((int)stream.Length);
+ }
+ }
+ parameters.Add($"{paramName}", contentBytes);
+ return $"@{paramName}";
+ }
}
}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelManagementServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelManagementServiceTests.cs
new file mode 100644
index 00000000..a435882c
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelManagementServiceTests.cs
@@ -0,0 +1,326 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.SqlTools.Extensibility;
+using Microsoft.SqlTools.Hosting.Protocol;
+using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
+using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
+using Microsoft.SqlTools.ServiceLayer.ModelManagement;
+using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Test.Common;
+using Microsoft.SqlTools.ServiceLayer.UnitTests;
+using Moq;
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Threading.Tasks;
+using NUnit.Framework;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ModelManagement
+{
+ public class ModelManagementServiceTests : ServiceTestBase
+ {
+
+ [Test]
+ public async Task VerifyDeleteModelRequest()
+ {
+ DeleteModelRequestParams requestParams = new DeleteModelRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name",
+ ModelId = 1
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.DeleteModel(It.IsAny(), requestParams));
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyRequst(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleDeleteModelRequest(requestParams, requestContext);
+ return null;
+ },
+ verify: (actual =>
+ {
+ Assert.NotNull(actual);
+ }));
+ }
+
+ [Test]
+ public async Task VerifyImportModelRequest()
+ {
+ ImportModelRequestParams requestParams = new ImportModelRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name",
+ Model = new ModelMetadata()
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.ImportModel(It.IsAny(), requestParams));
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyRequst(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleModelImportRequest(requestParams, requestContext);
+ return null;
+ },
+ verify: (actual =>
+ {
+ Assert.NotNull(actual);
+ }));
+ }
+
+ [Test]
+ public async Task VerifyUpdateModelRequest()
+ {
+ UpdateModelRequestParams requestParams = new UpdateModelRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name",
+ Model = new ModelMetadata()
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.UpdateModel(It.IsAny(), requestParams));
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyRequst(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleUpdateModelRequest(requestParams, requestContext);
+ return null;
+ },
+ verify: (actual =>
+ {
+ Assert.NotNull(actual);
+ }));
+ }
+
+ [Test]
+ public async Task VerifyDownloadModelRequest()
+ {
+ DownloadModelRequestParams requestParams = new DownloadModelRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name",
+ ModelId = 1
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.DownloadModel(It.IsAny(), requestParams)).Returns(() => "file path");
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyRequst(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleDownloadModelRequest(requestParams, requestContext);
+ return null;
+ },
+ verify: (actual =>
+ {
+ Assert.NotNull(actual);
+ Assert.AreEqual(actual.FilePath, "file path");
+ }));
+ }
+
+ [Test]
+ public async Task VerifyModelTableRequest()
+ {
+ VerifyModelTableRequestParams requestParams = new VerifyModelTableRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name"
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.VerifyImportTable(It.IsAny(), requestParams)).Returns(() => true);
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyRequst(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleVerifyModelTableRequest(requestParams, requestContext);
+ return null;
+ },
+ verify: (actual =>
+ {
+ Assert.NotNull(actual);
+ Assert.AreEqual(actual.Verified, true);
+ }));
+ }
+
+ [Test]
+ public async Task VerifyConfigureModelTableRequest()
+ {
+ ConfigureModelTableRequestParams requestParams = new ConfigureModelTableRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name"
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.ConfigureImportTable(It.IsAny(), requestParams));
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyRequst(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleConfigureModelTableRequest(requestParams, requestContext);
+ return null;
+ },
+ verify: (actual =>
+ {
+ Assert.NotNull(actual);
+ }));
+ }
+
+ [Test]
+ public async Task VerifyGetModelRequest()
+ {
+ GetModelsRequestParams requestParams = new GetModelsRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name",
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.GetModels(It.IsAny(), requestParams)).Returns(() => new List { new ModelMetadata() });
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyRequst(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleGetModelsRequest(requestParams, requestContext);
+ return null;
+ },
+ verify: (actual =>
+ {
+ Assert.NotNull(actual);
+ Assert.True(actual.Models.Count == 1);
+ }));
+ }
+
+ [Test]
+ public async Task VerifyRequestFailedResponse()
+ {
+ DeleteModelRequestParams requestParams = new DeleteModelRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name",
+ ModelId = 1
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.DeleteModel(It.IsAny(), requestParams)).Throws(new ApplicationException("error"));
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyError(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = connectionUrl;
+ await service.HandleDeleteModelRequest(requestParams, requestContext);
+ return null;
+ });
+ }
+
+ [Test]
+ public async Task VerifyInvalidConnectionResponse()
+ {
+ DeleteModelRequestParams requestParams = new DeleteModelRequestParams
+ {
+ DatabaseName = "db name",
+ SchemaName = "dbo",
+ TableName = "table name",
+ ModelId = 1
+ };
+ Mock operations = new Mock();
+ operations.Setup(x => x.DeleteModel(It.IsAny(), requestParams)).Throws(new ApplicationException("error"));
+ ModelManagementService service = new ModelManagementService()
+ {
+ ModelOperations = operations.Object
+ };
+
+ await VerifyError(
+ test: async (requestContext, connectionUrl) =>
+ {
+ requestParams.OwnerUri = "Invalid connection uri";
+ await service.HandleDeleteModelRequest(requestParams, requestContext);
+ return null;
+ });
+ }
+
+ public async Task VerifyRequst(Func, string, Task> test, Action verify)
+ {
+ using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
+ {
+ var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
+ await RunAndVerify(
+ test: (requestContext) => test(requestContext, queryTempFile.FilePath),
+ verify: verify);
+
+ ModelManagementService.Instance.ConnectionServiceInstance.Disconnect(new DisconnectParams
+ {
+ OwnerUri = queryTempFile.FilePath,
+ Type = ServiceLayer.Connection.ConnectionType.Default
+ });
+ }
+ }
+
+ public async Task VerifyError(Func, string, Task> test)
+ {
+ using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
+ {
+ var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
+ await RunAndVerifyError(
+ test: (requestContext) => test(requestContext, queryTempFile.FilePath));
+
+ ModelManagementService.Instance.ConnectionServiceInstance.Disconnect(new DisconnectParams
+ {
+ OwnerUri = queryTempFile.FilePath,
+ Type = ServiceLayer.Connection.ConnectionType.Default
+ });
+ }
+ }
+
+ protected override RegisteredServiceProvider CreateServiceProviderWithMinServices()
+ {
+ return base.CreateProvider();
+ }
+ }
+}
diff --git a/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelOperationsTests.cs b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelOperationsTests.cs
new file mode 100644
index 00000000..e51af553
--- /dev/null
+++ b/test/Microsoft.SqlTools.ServiceLayer.IntegrationTests/ModelManagement/ModelOperationsTests.cs
@@ -0,0 +1,350 @@
+//
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+//
+
+using Microsoft.Data.SqlClient;
+using Microsoft.SqlTools.ServiceLayer.Connection;
+using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
+using Microsoft.SqlTools.ServiceLayer.ModelManagement;
+using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
+using Microsoft.SqlTools.ServiceLayer.Test.Common;
+using NUnit.Framework;
+using System;
+using System.Data;
+using System.IO;
+using System.Linq;
+
+namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ModelManagement
+{
+ public class ModelOperationsTests
+ {
+ [Test]
+ public void VerifyImportTableShouldReturnTrueGivenNoTable()
+ {
+ bool expected = true;
+ bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+ ModelRequestBase request = new ModelRequestBase
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName
+ };
+ return modelOperations.VerifyImportTable(dbConnection, request);
+ });
+
+ Assert.AreEqual(expected, actual);
+ }
+
+ [Test]
+ public void ImportModelShouldImportSuccessfullyGivenValidModel()
+ {
+ string modelFilePath = Path.GetTempFileName();
+ File.WriteAllText(modelFilePath, "Test model");
+ ModelMetadata expected = new ModelMetadata
+ {
+ FilePath = modelFilePath,
+ Created = DateTime.Now.ToString(),
+ Description = "model description",
+ ModelName = "model name",
+ RunId = "run id",
+ Version = "1.2",
+ Framework = "ONNX",
+ FrameworkVersion = "1.3",
+ };
+ ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+
+ ImportModelRequestParams request = new ImportModelRequestParams
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName,
+ Model = expected
+ };
+ modelOperations.ConfigureImportTable(dbConnection, request);
+ modelOperations.ImportModel(dbConnection, request);
+ var models = modelOperations.GetModels(dbConnection, request);
+ return models.FirstOrDefault(x => x.ModelName == expected.ModelName);
+ });
+
+ Assert.IsNotNull(actual);
+ Assert.AreEqual(expected.RunId, actual.RunId);
+ Assert.AreEqual(expected.Description, actual.Description);
+ Assert.AreEqual(expected.Framework, actual.Framework);
+ Assert.AreEqual(expected.Version, actual.Version);
+ Assert.AreEqual(expected.FrameworkVersion, actual.FrameworkVersion);
+ Assert.IsFalse(string.IsNullOrWhiteSpace(actual.DeploymentTime));
+ Assert.IsFalse(string.IsNullOrWhiteSpace(actual.DeployedBy));
+ }
+
+ [Test]
+ public void DownloadModelShouldDownloadSuccessfullyGivenValidModel()
+ {
+ string expected = "Test model";
+
+ string modelFilePath = Path.GetTempFileName();
+ File.WriteAllText(modelFilePath, expected);
+ ModelMetadata model = new ModelMetadata
+ {
+ FilePath = modelFilePath,
+ Created = DateTime.Now.ToString(),
+ Description = "model description",
+ ModelName = "model name",
+ RunId = "run id",
+ Version = "1.2",
+ Framework = "ONNX",
+ FrameworkVersion = "1.3",
+ };
+ string actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+ ImportModelRequestParams request = new ImportModelRequestParams
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName,
+ Model = model
+ };
+ modelOperations.ConfigureImportTable(dbConnection, request);
+ modelOperations.ImportModel(dbConnection, request);
+ var models = modelOperations.GetModels(dbConnection, request);
+ var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
+ Assert.IsNotNull(importedModel);
+ DownloadModelRequestParams downloadRequest = new DownloadModelRequestParams
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName,
+ ModelId = importedModel.Id
+ };
+ string downloadedFile = modelOperations.DownloadModel(dbConnection, downloadRequest);
+ return File.ReadAllText(downloadedFile);
+
+ });
+
+ Assert.IsNotNull(actual);
+ Assert.AreEqual(expected, actual);
+ }
+
+ [Test]
+ public void UpdateModelShouldUpdateSuccessfullyGivenValidModel()
+ {
+ string modelFilePath = Path.GetTempFileName();
+ File.WriteAllText(modelFilePath, "Test");
+ ModelMetadata model = new ModelMetadata
+ {
+ FilePath = modelFilePath,
+ Created = DateTime.Now.ToString(),
+ Description = "model description",
+ ModelName = "model name",
+ RunId = "run id",
+ Version = "1.2",
+ Framework = "ONNX",
+ FrameworkVersion = "1.3",
+ };
+ ModelMetadata expected = model;
+
+ ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+ ImportModelRequestParams request = new ImportModelRequestParams
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName,
+ Model = model
+ };
+ modelOperations.ConfigureImportTable(dbConnection, request);
+ modelOperations.ImportModel(dbConnection, request);
+ var models = modelOperations.GetModels(dbConnection, request);
+ var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
+ Assert.IsNotNull(importedModel);
+
+ UpdateModelRequestParams updateRequest = new UpdateModelRequestParams
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName,
+ Model = new ModelMetadata
+ {
+ Description = request.Model.Description + "updated",
+ ModelName = request.Model.ModelName + "updated",
+ Version = request.Model.Version + "updated",
+ Framework = request.Model.Framework + "updated",
+ RunId = request.Model.RunId + "updated",
+ FrameworkVersion = request.Model.FrameworkVersion + "updated",
+ Id = importedModel.Id
+ }
+ };
+ modelOperations.UpdateModel(dbConnection, updateRequest);
+ models = modelOperations.GetModels(dbConnection, request);
+ var updatedModel = models.FirstOrDefault(x => x.ModelName == updateRequest.Model.ModelName);
+ Assert.IsNotNull(updatedModel);
+ return updatedModel;
+ });
+
+ Assert.IsNotNull(actual);
+ Assert.AreEqual(expected.Description + "updated", actual.Description);
+ Assert.AreEqual(expected.ModelName + "updated", actual.ModelName);
+ Assert.AreEqual(expected.Version + "updated", actual.Version);
+ Assert.AreEqual(expected.Framework + "updated", actual.Framework);
+ Assert.AreEqual(expected.FrameworkVersion + "updated", actual.FrameworkVersion);
+ Assert.AreEqual(expected.RunId + "updated", actual.RunId);
+ }
+
+ [Test]
+ public void DeleteModelShouldDeleteSuccessfullyGivenValidModel()
+ {
+ string modelFilePath = Path.GetTempFileName();
+ File.WriteAllText(modelFilePath, "Test");
+ ModelMetadata model = new ModelMetadata
+ {
+ FilePath = modelFilePath,
+ Created = DateTime.Now.ToString(),
+ Description = "model description",
+ ModelName = "model name",
+ RunId = "run id",
+ Version = "1.2",
+ Framework = "ONNX",
+ FrameworkVersion = "1.3",
+ };
+ ModelMetadata expected = model;
+
+ ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+ ImportModelRequestParams request = new ImportModelRequestParams
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName,
+ Model = model
+ };
+ modelOperations.ConfigureImportTable(dbConnection, request);
+ modelOperations.ImportModel(dbConnection, request);
+ var models = modelOperations.GetModels(dbConnection, request);
+ var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
+ Assert.IsNotNull(importedModel);
+
+ DeleteModelRequestParams deleteRequest = new DeleteModelRequestParams
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName,
+ ModelId = importedModel.Id
+ };
+ modelOperations.DeleteModel(dbConnection, deleteRequest);
+ models = modelOperations.GetModels(dbConnection, request);
+ var updatedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
+ return updatedModel;
+ });
+
+ Assert.IsNull(actual);
+ }
+
+ [Test]
+ public void VerifyImportTableShouldReturnTrueGivenTableCreatdByTheService()
+ {
+ bool expected = true;
+ bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+ ModelRequestBase request = new ModelRequestBase
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName
+ };
+ modelOperations.ConfigureImportTable(dbConnection, request);
+ return modelOperations.VerifyImportTable(dbConnection, request);
+ });
+
+ Assert.AreEqual(expected, actual);
+ }
+
+ [Test]
+ public void VerifyImportTableShouldReturnTrueGivenTableNameThatNeedsEscaping()
+ {
+ bool expected = true;
+ bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+ ModelRequestBase request = new ModelRequestBase
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName
+ };
+ modelOperations.ConfigureImportTable(dbConnection, request);
+ return modelOperations.VerifyImportTable(dbConnection, request);
+ }, null, "models[]'");
+
+ Assert.AreEqual(expected, actual);
+ }
+
+ [Test]
+ public void VerifyImportTableShouldReturnFalseGivenInvalidDbName()
+ {
+ Assert.Throws(typeof(SqlException), () => VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ ModelOperations modelOperations = new ModelOperations();
+ ModelRequestBase request = new ModelRequestBase
+ {
+ DatabaseName = "invalidDb",
+ SchemaName = "dbo",
+ TableName = tableName
+ };
+ modelOperations.ConfigureImportTable(dbConnection, request);
+ return modelOperations.VerifyImportTable(dbConnection, request);
+ }));
+ }
+
+ [Test]
+ public void VerifyImportTableShouldReturnFalseGivenInvalidTable()
+ {
+ bool expected = false;
+ bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
+ {
+ dbConnection.ChangeDatabase(databaseName);
+ using (IDbCommand command = dbConnection.CreateCommand())
+ {
+ command.CommandText = $"Create Table {tableName} (Id int, name varchar(10))";
+ command.ExecuteNonQuery();
+ }
+ dbConnection.ChangeDatabase("master");
+
+ ModelOperations modelOperations = new ModelOperations();
+ ModelRequestBase request = new ModelRequestBase
+ {
+ DatabaseName = databaseName,
+ SchemaName = "dbo",
+ TableName = tableName
+ };
+ return modelOperations.VerifyImportTable(dbConnection, request);
+ });
+
+ Assert.AreEqual(expected, actual);
+ }
+
+ private T VerifyModelOperation(Func operation, string dbName = null, string tbName = null)
+ {
+ string databaseName = dbName ?? "testModels_" + new Random().Next(10000000, 99999999);
+ string tableName = tbName ?? "models";
+ using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName))
+ {
+ var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo(databaseName);
+ ConnectionInfo connInfo = liveConnection.ConnectionInfo;
+ IDbConnection dbConnection = ConnectionService.OpenSqlConnection(connInfo);
+ dbConnection.ChangeDatabase("master");
+
+ T result = operation(dbConnection, databaseName, tableName);
+ Assert.AreEqual(dbConnection.Database, "master");
+ return result;
+ }
+ }
+ }
+}