mirror of
https://github.com/ckaczor/sqltoolsservice.git
synced 2026-02-17 02:51:45 -05:00
Added service for model management for machine learning extension (#1138)
* Added service for model management for ml extension
This commit is contained in:
@@ -32,6 +32,7 @@ using Microsoft.SqlTools.ServiceLayer.SqlAssessment;
|
|||||||
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
using Microsoft.SqlTools.ServiceLayer.SqlContext;
|
||||||
using Microsoft.SqlTools.ServiceLayer.Workspace;
|
using Microsoft.SqlTools.ServiceLayer.Workspace;
|
||||||
using Microsoft.SqlTools.ServiceLayer.NotebookConvert;
|
using Microsoft.SqlTools.ServiceLayer.NotebookConvert;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.ModelManagement;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ServiceLayer
|
namespace Microsoft.SqlTools.ServiceLayer
|
||||||
{
|
{
|
||||||
@@ -138,6 +139,9 @@ namespace Microsoft.SqlTools.ServiceLayer
|
|||||||
ExternalLanguageService.Instance.InitializeService(serviceHost);
|
ExternalLanguageService.Instance.InitializeService(serviceHost);
|
||||||
serviceProvider.RegisterSingleService(ExternalLanguageService.Instance);
|
serviceProvider.RegisterSingleService(ExternalLanguageService.Instance);
|
||||||
|
|
||||||
|
ModelManagementService.Instance.InitializeService(serviceHost);
|
||||||
|
serviceProvider.RegisterSingleService(ModelManagementService.Instance);
|
||||||
|
|
||||||
SqlAssessmentService.Instance.InitializeService(serviceHost);
|
SqlAssessmentService.Instance.InitializeService(serviceHost);
|
||||||
serviceProvider.RegisterSingleService(SqlAssessmentService.Instance);
|
serviceProvider.RegisterSingleService(SqlAssessmentService.Instance);
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class ConfigureModelTableRequestParams : ModelRequestBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Response class for get model
|
||||||
|
/// </summary>
|
||||||
|
public class ConfigureModelTableResponseParams : ModelResponseBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Request class to get models
|
||||||
|
/// </summary>
|
||||||
|
public class ConfigureModelTableRequest
|
||||||
|
{
|
||||||
|
public static readonly
|
||||||
|
RequestType<ConfigureModelTableRequestParams, ConfigureModelTableResponseParams> Type =
|
||||||
|
RequestType<ConfigureModelTableRequestParams, ConfigureModelTableResponseParams>.Create("models/configure");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class DeleteModelRequestParams : ModelRequestBase
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Model id
|
||||||
|
/// </summary>
|
||||||
|
public int ModelId { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Response class for delete model
|
||||||
|
/// </summary>
|
||||||
|
public class DeleteModelResponseParams : ModelResponseBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Request class to delete a model
|
||||||
|
/// </summary>
|
||||||
|
public class DeleteModelRequest
|
||||||
|
{
|
||||||
|
public static readonly
|
||||||
|
RequestType<DeleteModelRequestParams, DeleteModelResponseParams> Type =
|
||||||
|
RequestType<DeleteModelRequestParams, DeleteModelResponseParams>.Create("models/delete");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class DownloadModelRequestParams : ModelRequestBase
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Model id
|
||||||
|
/// </summary>
|
||||||
|
public int ModelId { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Response class for import model
|
||||||
|
/// </summary>
|
||||||
|
public class DownloadModelResponseParams : ModelResponseBase
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Downloaded file path
|
||||||
|
/// </summary>
|
||||||
|
public string FilePath { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Request class to delete a model
|
||||||
|
/// </summary>
|
||||||
|
public class DownloadModelRequest
|
||||||
|
{
|
||||||
|
public static readonly
|
||||||
|
RequestType<DownloadModelRequestParams, DownloadModelResponseParams> Type =
|
||||||
|
RequestType<DownloadModelRequestParams, DownloadModelResponseParams>.Create("models/download");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class GetModelsRequestParams : ModelRequestBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Response class for get model
|
||||||
|
/// </summary>
|
||||||
|
public class GetModelsResponseParams : ModelResponseBase
|
||||||
|
{
|
||||||
|
public List<ModelMetadata> Models { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Request class to get models
|
||||||
|
/// </summary>
|
||||||
|
public class GetModelsRequest
|
||||||
|
{
|
||||||
|
public static readonly
|
||||||
|
RequestType<GetModelsRequestParams, GetModelsResponseParams> Type =
|
||||||
|
RequestType<GetModelsRequestParams, GetModelsResponseParams>.Create("models/get");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class ImportModelRequestParams : ModelRequestBase
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Model metadata
|
||||||
|
/// </summary>
|
||||||
|
public ModelMetadata Model { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Response class for import model
|
||||||
|
/// </summary>
|
||||||
|
public class ImportModelResponseParams : ModelResponseBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Request class to import a model
|
||||||
|
/// </summary>
|
||||||
|
public class ImportModelRequest
|
||||||
|
{
|
||||||
|
public static readonly
|
||||||
|
RequestType<ImportModelRequestParams, ImportModelResponseParams> Type =
|
||||||
|
RequestType<ImportModelRequestParams, ImportModelResponseParams>.Create("models/import");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using System;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Model metadata
|
||||||
|
/// </summary>
|
||||||
|
public class ModelMetadata
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Model id
|
||||||
|
/// </summary>
|
||||||
|
public int Id { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model content length
|
||||||
|
/// </summary>
|
||||||
|
public Int64 ContentLength { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model name
|
||||||
|
/// </summary>
|
||||||
|
public string ModelName { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model created date
|
||||||
|
/// </summary>
|
||||||
|
public string Created { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model deployment time
|
||||||
|
/// </summary>
|
||||||
|
public string DeploymentTime { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model version
|
||||||
|
/// </summary>
|
||||||
|
public string Version { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model description
|
||||||
|
/// </summary>
|
||||||
|
public string Description { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model file path
|
||||||
|
/// </summary>
|
||||||
|
public string FilePath { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model framework
|
||||||
|
/// </summary>
|
||||||
|
public string Framework { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model framework version
|
||||||
|
/// </summary>
|
||||||
|
public string FrameworkVersion { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model run id
|
||||||
|
/// </summary>
|
||||||
|
public string RunId { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Model deploy by
|
||||||
|
/// </summary>
|
||||||
|
public string DeployedBy { get; set; }
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class ModelRequestBase
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Models database name
|
||||||
|
/// </summary>
|
||||||
|
public string DatabaseName { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The schema for model table
|
||||||
|
/// </summary>
|
||||||
|
public string SchemaName { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Models table name
|
||||||
|
/// </summary>
|
||||||
|
public string TableName { get; set; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Connection uri
|
||||||
|
/// </summary>
|
||||||
|
public string OwnerUri { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class ModelResponseBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class UpdateModelRequestParams : ImportModelRequestParams
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Response class for import model
|
||||||
|
/// </summary>
|
||||||
|
public class UpdateModelResponseParams : ModelResponseBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Request class to import a model
|
||||||
|
/// </summary>
|
||||||
|
public class UpdateModelRequest
|
||||||
|
{
|
||||||
|
public static readonly
|
||||||
|
RequestType<UpdateModelRequestParams, UpdateModelResponseParams> Type =
|
||||||
|
RequestType<UpdateModelRequestParams, UpdateModelResponseParams>.Create("models/update");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol.Contracts;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts
|
||||||
|
{
|
||||||
|
public class VerifyModelTableRequestParams : ModelRequestBase
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Response class for verify model table
|
||||||
|
/// </summary>
|
||||||
|
public class VerifyModelTableResponseParams : ModelResponseBase
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Specified is model table is verified
|
||||||
|
/// </summary>
|
||||||
|
public bool Verified { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Request class to verify model table
|
||||||
|
/// </summary>
|
||||||
|
public class VerifyModelTableRequest
|
||||||
|
{
|
||||||
|
public static readonly
|
||||||
|
RequestType<VerifyModelTableRequestParams, VerifyModelTableResponseParams> Type =
|
||||||
|
RequestType<VerifyModelTableRequestParams, VerifyModelTableResponseParams>.Create("models/verify");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,241 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Hosting;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
|
||||||
|
using Microsoft.SqlTools.Utility;
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Data;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement
|
||||||
|
{
|
||||||
|
public class ModelManagementService
|
||||||
|
{
|
||||||
|
private ModelOperations serviceOperations = new ModelOperations();
|
||||||
|
private ConnectionService connectionService = null;
|
||||||
|
private static readonly Lazy<ModelManagementService> instance = new Lazy<ModelManagementService>(() => new ModelManagementService());
|
||||||
|
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets the singleton instance object
|
||||||
|
/// </summary>
|
||||||
|
public static ModelManagementService Instance
|
||||||
|
{
|
||||||
|
get { return instance.Value; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Internal for testing purposes only
|
||||||
|
/// </summary>
|
||||||
|
internal ConnectionService ConnectionServiceInstance
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
if (connectionService == null)
|
||||||
|
{
|
||||||
|
connectionService = ConnectionService.Instance;
|
||||||
|
}
|
||||||
|
return connectionService;
|
||||||
|
}
|
||||||
|
|
||||||
|
set
|
||||||
|
{
|
||||||
|
connectionService = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public ModelOperations ModelOperations
|
||||||
|
{
|
||||||
|
get
|
||||||
|
{
|
||||||
|
return serviceOperations;
|
||||||
|
}
|
||||||
|
set
|
||||||
|
{
|
||||||
|
serviceOperations = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void InitializeService(ServiceHost serviceHost)
|
||||||
|
{
|
||||||
|
serviceHost.SetRequestHandler(ImportModelRequest.Type, this.HandleModelImportRequest);
|
||||||
|
serviceHost.SetRequestHandler(ConfigureModelTableRequest.Type, this.HandleConfigureModelTableRequest);
|
||||||
|
serviceHost.SetRequestHandler(DeleteModelRequest.Type, this.HandleDeleteModelRequest);
|
||||||
|
serviceHost.SetRequestHandler(DownloadModelRequest.Type, this.HandleDownloadModelRequest);
|
||||||
|
serviceHost.SetRequestHandler(GetModelsRequest.Type, this.HandleGetModelsRequest);
|
||||||
|
serviceHost.SetRequestHandler(UpdateModelRequest.Type, this.HandleUpdateModelRequest);
|
||||||
|
serviceHost.SetRequestHandler(VerifyModelTableRequest.Type, this.HandleVerifyModelTableRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles import model request
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="parameters">Request parameters</param>
|
||||||
|
/// <param name="requestContext">Request Context</param>
|
||||||
|
public async Task HandleModelImportRequest(ImportModelRequestParams parameters, RequestContext<ImportModelResponseParams> requestContext)
|
||||||
|
{
|
||||||
|
Logger.Write(TraceEventType.Verbose, "HandleModelImportRequest");
|
||||||
|
ImportModelResponseParams response = new ImportModelResponseParams
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
|
||||||
|
{
|
||||||
|
ModelOperations.ImportModel(dbConnection, parameters);
|
||||||
|
return response;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles get models request
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="parameters">Request parameters</param>
|
||||||
|
/// <param name="requestContext">Request Context</param>
|
||||||
|
public async Task HandleGetModelsRequest(GetModelsRequestParams parameters, RequestContext<GetModelsResponseParams> requestContext)
|
||||||
|
{
|
||||||
|
Logger.Write(TraceEventType.Verbose, "HandleGetModelsRequest");
|
||||||
|
GetModelsResponseParams response = new GetModelsResponseParams
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
|
||||||
|
{
|
||||||
|
List<ModelMetadata> models = ModelOperations.GetModels(dbConnection, parameters);
|
||||||
|
response.Models = models;
|
||||||
|
return response;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles update model request
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="parameters">Request parameters</param>
|
||||||
|
/// <param name="requestContext">Request Context</param>
|
||||||
|
public async Task HandleUpdateModelRequest(UpdateModelRequestParams parameters, RequestContext<UpdateModelResponseParams> requestContext)
|
||||||
|
{
|
||||||
|
Logger.Write(TraceEventType.Verbose, "HandleUpdateModelRequest");
|
||||||
|
UpdateModelResponseParams response = new UpdateModelResponseParams
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
|
||||||
|
{
|
||||||
|
ModelOperations.UpdateModel(dbConnection, parameters);
|
||||||
|
return response;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles delete model request
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="parameters">Request parameters</param>
|
||||||
|
/// <param name="requestContext">Request Context</param>
|
||||||
|
public async Task HandleDeleteModelRequest(DeleteModelRequestParams parameters, RequestContext<DeleteModelResponseParams> requestContext)
|
||||||
|
{
|
||||||
|
Logger.Write(TraceEventType.Verbose, "HandleDeleteModelRequest");
|
||||||
|
DeleteModelResponseParams response = new DeleteModelResponseParams
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
|
||||||
|
{
|
||||||
|
ModelOperations.DeleteModel(dbConnection, parameters);
|
||||||
|
return response;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles download model request
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="parameters">Request parameters</param>
|
||||||
|
/// <param name="requestContext">Request Context</param>
|
||||||
|
public async Task HandleDownloadModelRequest(DownloadModelRequestParams parameters, RequestContext<DownloadModelResponseParams> requestContext)
|
||||||
|
{
|
||||||
|
Logger.Write(TraceEventType.Verbose, "HandleDownloadModelRequest");
|
||||||
|
DownloadModelResponseParams response = new DownloadModelResponseParams
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
|
||||||
|
{
|
||||||
|
response.FilePath = ModelOperations.DownloadModel(dbConnection, parameters);
|
||||||
|
return response;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles verify model table request
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="parameters">Request parameters</param>
|
||||||
|
/// <param name="requestContext">Request Context</param>
|
||||||
|
public async Task HandleVerifyModelTableRequest(VerifyModelTableRequestParams parameters, RequestContext<VerifyModelTableResponseParams> requestContext)
|
||||||
|
{
|
||||||
|
Logger.Write(TraceEventType.Verbose, "HandleVerifyModelTableRequest");
|
||||||
|
VerifyModelTableResponseParams response = new VerifyModelTableResponseParams
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
|
||||||
|
{
|
||||||
|
response.Verified = ModelOperations.VerifyImportTable(dbConnection, parameters);
|
||||||
|
return response;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Handles configure model table request
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="parameters">Request parameters</param>
|
||||||
|
/// <param name="requestContext">Request Context</param>
|
||||||
|
public async Task HandleConfigureModelTableRequest(ConfigureModelTableRequestParams parameters, RequestContext<ConfigureModelTableResponseParams> requestContext)
|
||||||
|
{
|
||||||
|
Logger.Write(TraceEventType.Verbose, "HandleConfigureModelTableRequest");
|
||||||
|
ConfigureModelTableResponseParams response = new ConfigureModelTableResponseParams();
|
||||||
|
|
||||||
|
await HandleRequest(parameters, response, requestContext, (dbConnection, parameters, response) =>
|
||||||
|
{
|
||||||
|
ModelOperations.ConfigureImportTable(dbConnection, parameters);
|
||||||
|
return response;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task HandleRequest<T, TResponse>(
|
||||||
|
T parameters,
|
||||||
|
TResponse response,
|
||||||
|
RequestContext<TResponse> requestContext,
|
||||||
|
Func<IDbConnection, T, TResponse, TResponse> operation) where T : ModelRequestBase where TResponse : ModelResponseBase
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
ConnectionInfo connInfo;
|
||||||
|
ConnectionServiceInstance.TryFindConnection(
|
||||||
|
parameters.OwnerUri,
|
||||||
|
out connInfo);
|
||||||
|
if (connInfo == null)
|
||||||
|
{
|
||||||
|
await requestContext.SendError(new Exception(SR.ConnectionServiceDbErrorDefaultNotConnected(parameters.OwnerUri)));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
using (IDbConnection dbConnection = ConnectionService.OpenSqlConnection(connInfo))
|
||||||
|
{
|
||||||
|
response = operation(dbConnection, parameters, response);
|
||||||
|
}
|
||||||
|
await requestContext.SendResult(response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (Exception e)
|
||||||
|
{
|
||||||
|
// Exception related to run task will be captured here
|
||||||
|
await requestContext.SendError(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,416 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Management;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Utility;
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Data;
|
||||||
|
using System.IO;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.ModelManagement
|
||||||
|
{
|
||||||
|
public class ModelOperations
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Returns models from given table
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connection">Db connection</param>
|
||||||
|
/// <param name="request">model request</param>
|
||||||
|
/// <returns>Models</returns>
|
||||||
|
public virtual List<ModelMetadata> GetModels(IDbConnection connection, ModelRequestBase request)
|
||||||
|
{
|
||||||
|
List<ModelMetadata> models = new List<ModelMetadata>();
|
||||||
|
using (IDbCommand command = connection.CreateCommand())
|
||||||
|
{
|
||||||
|
command.CommandText = GetSelectModelsQuery(request.DatabaseName, request.TableName, request.SchemaName);
|
||||||
|
|
||||||
|
using (IDataReader reader = command.ExecuteReader())
|
||||||
|
{
|
||||||
|
while (reader.Read())
|
||||||
|
{
|
||||||
|
models.Add(LoadModelMetadata(reader));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return models;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Downlaods model content into a temp file and returns the file path
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connection">Db connection</param>
|
||||||
|
/// <param name="request">model request</param>
|
||||||
|
/// <returns>Model file path</returns>
|
||||||
|
public virtual string DownloadModel(IDbConnection connection, DownloadModelRequestParams request)
|
||||||
|
{
|
||||||
|
string fileName = Path.GetTempFileName();
|
||||||
|
using (IDbCommand command = connection.CreateCommand())
|
||||||
|
{
|
||||||
|
Dictionary<string, object> parameters = new Dictionary<string, object>();
|
||||||
|
command.CommandText = GetSelectModelContentQuery(request.DatabaseName, request.TableName, request.SchemaName, request.ModelId, parameters);
|
||||||
|
|
||||||
|
foreach (var item in parameters)
|
||||||
|
{
|
||||||
|
var parameter = command.CreateParameter();
|
||||||
|
parameter.ParameterName = item.Key;
|
||||||
|
parameter.Value = item.Value;
|
||||||
|
command.Parameters.Add(parameter);
|
||||||
|
}
|
||||||
|
using (IDataReader reader = command.ExecuteReader())
|
||||||
|
{
|
||||||
|
while (reader.Read())
|
||||||
|
{
|
||||||
|
File.WriteAllBytes(fileName, (byte[])reader[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fileName;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Import model to given table
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connection">Db connection</param>
|
||||||
|
/// <param name="request">model request</param>
|
||||||
|
public virtual void ImportModel(IDbConnection connection, ImportModelRequestParams request)
|
||||||
|
{
|
||||||
|
WithDbChange(connection, request, (request) =>
|
||||||
|
{
|
||||||
|
Dictionary<string, object> parameters = new Dictionary<string, object>();
|
||||||
|
|
||||||
|
using (IDbCommand command = connection.CreateCommand())
|
||||||
|
{
|
||||||
|
command.CommandText = GetInsertModelQuery(request.TableName, request.SchemaName, request.Model, parameters);
|
||||||
|
|
||||||
|
foreach (var item in parameters)
|
||||||
|
{
|
||||||
|
var parameter = command.CreateParameter();
|
||||||
|
parameter.ParameterName = item.Key;
|
||||||
|
parameter.Value = item.Value;
|
||||||
|
command.Parameters.Add(parameter);
|
||||||
|
}
|
||||||
|
command.ExecuteNonQuery();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Updates model
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connection">Db connection</param>
|
||||||
|
/// <param name="request">model request</param>
|
||||||
|
public virtual void UpdateModel(IDbConnection connection, UpdateModelRequestParams request)
|
||||||
|
{
|
||||||
|
WithDbChange(connection, request, (request) =>
|
||||||
|
{
|
||||||
|
Dictionary<string, object> parameters = new Dictionary<string, object>();
|
||||||
|
using (IDbCommand command = connection.CreateCommand())
|
||||||
|
{
|
||||||
|
command.CommandText = GetUpdateModelQuery(request.TableName, request.SchemaName, request.Model, parameters);
|
||||||
|
|
||||||
|
foreach (var item in parameters)
|
||||||
|
{
|
||||||
|
var parameter = command.CreateParameter();
|
||||||
|
parameter.ParameterName = item.Key;
|
||||||
|
parameter.Value = item.Value;
|
||||||
|
command.Parameters.Add(parameter);
|
||||||
|
}
|
||||||
|
command.ExecuteNonQuery();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Deletes a model from the given table
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connection">Db connection</param>
|
||||||
|
/// <param name="request">model request</param>
|
||||||
|
public virtual void DeleteModel(IDbConnection connection, DeleteModelRequestParams request)
|
||||||
|
{
|
||||||
|
WithDbChange(connection, request, (request) =>
|
||||||
|
{
|
||||||
|
Dictionary<string, object> parameters = new Dictionary<string, object>();
|
||||||
|
using (IDbCommand command = connection.CreateCommand())
|
||||||
|
{
|
||||||
|
command.CommandText = GetDeleteModelQuery(request.TableName, request.SchemaName, request.ModelId, parameters);
|
||||||
|
|
||||||
|
foreach (var item in parameters)
|
||||||
|
{
|
||||||
|
var parameter = command.CreateParameter();
|
||||||
|
parameter.ParameterName = item.Key;
|
||||||
|
parameter.Value = item.Value;
|
||||||
|
command.Parameters.Add(parameter);
|
||||||
|
}
|
||||||
|
command.ExecuteNonQuery();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Configures model table
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connection">Db connection</param>
|
||||||
|
/// <param name="request">model request</param>
|
||||||
|
public virtual void ConfigureImportTable(IDbConnection connection, ModelRequestBase request)
|
||||||
|
{
|
||||||
|
WithDbChange(connection, request, (request) =>
|
||||||
|
{
|
||||||
|
Dictionary<string, object> parameters = new Dictionary<string, object>();
|
||||||
|
using (IDbCommand command = connection.CreateCommand())
|
||||||
|
{
|
||||||
|
command.CommandText = GetCreateModelTableQuery(request.TableName, request.SchemaName);
|
||||||
|
|
||||||
|
foreach (var item in parameters)
|
||||||
|
{
|
||||||
|
var parameter = command.CreateParameter();
|
||||||
|
parameter.ParameterName = item.Key;
|
||||||
|
parameter.Value = item.Value;
|
||||||
|
command.Parameters.Add(parameter);
|
||||||
|
}
|
||||||
|
command.ExecuteNonQuery();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Verifies model table
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="connection">Db connection</param>
|
||||||
|
/// <param name="request">model request</param>
|
||||||
|
public virtual bool VerifyImportTable(IDbConnection connection, ModelRequestBase request)
|
||||||
|
{
|
||||||
|
int result = WithDbChange(connection, request, (request) =>
|
||||||
|
{
|
||||||
|
Dictionary<string, object> parameters = new Dictionary<string, object>();
|
||||||
|
using (IDbCommand command = connection.CreateCommand())
|
||||||
|
{
|
||||||
|
command.CommandText = GetConfigTableVerificationQuery(request.DatabaseName, request.TableName, request.SchemaName);
|
||||||
|
|
||||||
|
command.ExecuteNonQuery();
|
||||||
|
using (IDataReader reader = command.ExecuteReader())
|
||||||
|
{
|
||||||
|
while (reader.Read())
|
||||||
|
{
|
||||||
|
return reader.GetInt32(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return result == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
private TResult WithDbChange<T, TResult>(IDbConnection connection, T request, Func<T, TResult> operation) where T : ModelRequestBase
|
||||||
|
{
|
||||||
|
string currentDb = connection.Database;
|
||||||
|
if (connection.Database != request.DatabaseName)
|
||||||
|
{
|
||||||
|
connection.ChangeDatabase(request.DatabaseName);
|
||||||
|
}
|
||||||
|
TResult result = operation(request);
|
||||||
|
|
||||||
|
if (connection.Database != currentDb)
|
||||||
|
{
|
||||||
|
connection.ChangeDatabase(currentDb);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ModelMetadata LoadModelMetadata(IDataReader reader)
|
||||||
|
{
|
||||||
|
return new ModelMetadata
|
||||||
|
{
|
||||||
|
Id = reader.GetInt32(0),
|
||||||
|
ModelName = reader.GetString(1),
|
||||||
|
Description = reader.IsDBNull(2) ? string.Empty : reader.GetString(2),
|
||||||
|
Version = reader.IsDBNull(3) ? string.Empty : reader.GetString(3),
|
||||||
|
Created = reader.IsDBNull(4) ? string.Empty : reader.GetDateTime(4).ToString(),
|
||||||
|
Framework = reader.IsDBNull(5) ? string.Empty : reader.GetString(5),
|
||||||
|
FrameworkVersion = reader.IsDBNull(6) ? string.Empty : reader.GetString(6),
|
||||||
|
DeploymentTime = reader.IsDBNull(7) ? string.Empty : reader.GetDateTime(7).ToString(),
|
||||||
|
DeployedBy = reader.IsDBNull(8) ? string.Empty : reader.GetString(8),
|
||||||
|
RunId = reader.IsDBNull(9) ? string.Empty : reader.GetString(9),
|
||||||
|
ContentLength = reader.GetInt64(10),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private const string ModelSelectColumns = @"
|
||||||
|
SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, User_Name(deployed_by), run_id,
|
||||||
|
len(model)";
|
||||||
|
|
||||||
|
private static string GetThreePartsTableName(string dbName, string tableName, string schemaName)
|
||||||
|
{
|
||||||
|
return $"[{CUtils.EscapeStringCBracket(dbName)}].[{CUtils.EscapeStringCBracket(schemaName)}].[{CUtils.EscapeStringCBracket(tableName)}]";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string GetTwoPartsTableName(string tableName, string schemaName)
|
||||||
|
{
|
||||||
|
return $"[{CUtils.EscapeStringCBracket(schemaName)}].[{CUtils.EscapeStringCBracket(tableName)}]";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string GetSelectModelsQuery(string dbName, string tableName, string schemaName)
|
||||||
|
{
|
||||||
|
return $@"
|
||||||
|
{ModelSelectColumns}
|
||||||
|
FROM {GetThreePartsTableName(dbName, tableName, schemaName)}
|
||||||
|
WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
|
||||||
|
ORDER BY model_id";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string GetConfigTableVerificationQuery(string dbName, string tableName, string schemaName)
|
||||||
|
{
|
||||||
|
string twoPartsTableName = GetTwoPartsTableName(CUtils.EscapeStringSQuote(tableName), CUtils.EscapeStringSQuote(schemaName));
|
||||||
|
return $@"
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT name
|
||||||
|
FROM sys.databases
|
||||||
|
WHERE name = N'{CUtils.EscapeStringSQuote(dbName)}'
|
||||||
|
)
|
||||||
|
BEGIN
|
||||||
|
SELECT 0
|
||||||
|
END
|
||||||
|
ELSE
|
||||||
|
BEGIN
|
||||||
|
USE [{CUtils.EscapeStringCBracket(dbName)}]
|
||||||
|
IF EXISTS
|
||||||
|
( SELECT t.name, s.name
|
||||||
|
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||||
|
WHERE t.name = '{CUtils.EscapeStringSQuote(tableName)}'
|
||||||
|
AND s.name = '{CUtils.EscapeStringSQuote(schemaName)}'
|
||||||
|
)
|
||||||
|
BEGIN
|
||||||
|
IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_name')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_id')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_description')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_framework')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_framework_version')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_version')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_creation_time')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='model_deployment_time')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='deployed_by')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('{twoPartsTableName}') AND NAME='run_id')
|
||||||
|
BEGIN
|
||||||
|
SELECT 1
|
||||||
|
END
|
||||||
|
ELSE
|
||||||
|
BEGIN
|
||||||
|
SELECT 0
|
||||||
|
END
|
||||||
|
END
|
||||||
|
ELSE
|
||||||
|
SELECT 1
|
||||||
|
END";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private static string GetCreateModelTableQuery(string tableName, string schemaName)
|
||||||
|
{
|
||||||
|
return $@"
|
||||||
|
IF NOT EXISTS
|
||||||
|
( SELECT t.name, s.name
|
||||||
|
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||||
|
WHERE t.name = '{CUtils.EscapeStringSQuote(tableName)}'
|
||||||
|
AND s.name = '{CUtils.EscapeStringSQuote(schemaName)}'
|
||||||
|
)
|
||||||
|
BEGIN
|
||||||
|
CREATE TABLE {GetTwoPartsTableName(tableName, schemaName)} (
|
||||||
|
[model_id] [int] IDENTITY(1,1) NOT NULL,
|
||||||
|
[model_name] [varchar](256) NOT NULL,
|
||||||
|
[model_framework] [varchar](256) NULL,
|
||||||
|
[model_framework_version] [varchar](256) NULL,
|
||||||
|
[model] [varbinary](max) NOT NULL,
|
||||||
|
[model_version] [varchar](256) NULL,
|
||||||
|
[model_creation_time] [datetime2] NULL,
|
||||||
|
[model_deployment_time] [datetime2] NULL,
|
||||||
|
[deployed_by] [int] NULL,
|
||||||
|
[model_description] [varchar](256) NULL,
|
||||||
|
[run_id] [varchar](256) NULL,
|
||||||
|
CONSTRAINT [{CUtils.EscapeStringCBracket(tableName)}_models_pk] PRIMARY KEY CLUSTERED
|
||||||
|
(
|
||||||
|
[model_id] ASC
|
||||||
|
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||||
|
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||||
|
ALTER TABLE {GetTwoPartsTableName(tableName, schemaName)} ADD CONSTRAINT [{CUtils.EscapeStringCBracket(tableName)}_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
|
||||||
|
END
|
||||||
|
";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string GetInsertModelQuery(string tableName, string schemaName, ModelMetadata model, Dictionary<string, object> parameters)
|
||||||
|
{
|
||||||
|
string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName);
|
||||||
|
|
||||||
|
return $@"
|
||||||
|
INSERT INTO {twoPartsTableName}
|
||||||
|
(model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id, deployed_by)
|
||||||
|
VALUES (
|
||||||
|
{DatabaseUtils.AddStringParameterForInsert(model.ModelName ?? "")},
|
||||||
|
{DatabaseUtils.AddByteArrayParameterForInsert("Content", model.FilePath ?? "", parameters)},
|
||||||
|
{DatabaseUtils.AddStringParameterForInsert(model.Version ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForInsert(model.Description ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForInsert(model.Created)},
|
||||||
|
{DatabaseUtils.AddStringParameterForInsert(model.Framework ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForInsert(model.FrameworkVersion ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForInsert(model.RunId ?? "")},
|
||||||
|
USER_ID (Current_User)
|
||||||
|
)
|
||||||
|
";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string GetUpdateModelQuery(string tableName, string schemaName, ModelMetadata model, Dictionary<string, object> parameters)
|
||||||
|
{
|
||||||
|
string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName);
|
||||||
|
parameters.Add(ModelIdParameterName, model.Id);
|
||||||
|
|
||||||
|
return $@"
|
||||||
|
UPDATE {twoPartsTableName}
|
||||||
|
SET
|
||||||
|
{DatabaseUtils.AddStringParameterForUpdate("model_name", model.ModelName ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForUpdate("model_version", model.Version ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForUpdate("model_description", model.Description ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForUpdate("model_creation_time", model.Created)},
|
||||||
|
{DatabaseUtils.AddStringParameterForUpdate("model_framework", model.Framework ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForUpdate("model_framework_version", model.FrameworkVersion ?? "")},
|
||||||
|
{DatabaseUtils.AddStringParameterForUpdate("run_id", model.RunId ?? "")}
|
||||||
|
WHERE model_id = @{ModelIdParameterName}
|
||||||
|
|
||||||
|
";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string GetDeleteModelQuery(string tableName, string schemaName, int modelId, Dictionary<string, object> parameters)
|
||||||
|
{
|
||||||
|
string twoPartsTableName = GetTwoPartsTableName(tableName, schemaName);
|
||||||
|
parameters.Add(ModelIdParameterName, modelId);
|
||||||
|
|
||||||
|
return $@"
|
||||||
|
DELETE FROM {twoPartsTableName}
|
||||||
|
WHERE model_id = @{ModelIdParameterName}
|
||||||
|
";
|
||||||
|
}
|
||||||
|
|
||||||
|
private static string GetSelectModelContentQuery(string dbName, string tableName, string schemaName, int modelId, Dictionary<string, object> parameters)
|
||||||
|
{
|
||||||
|
string threePartsTableName = GetThreePartsTableName(dbName, tableName, schemaName);
|
||||||
|
parameters.Add(ModelIdParameterName, modelId);
|
||||||
|
|
||||||
|
return $@"
|
||||||
|
SELECT model
|
||||||
|
FROM {threePartsTableName}
|
||||||
|
WHERE model_id = @{ModelIdParameterName}
|
||||||
|
";
|
||||||
|
}
|
||||||
|
|
||||||
|
private const string ModelIdParameterName = "ModelId";
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,7 +3,10 @@
|
|||||||
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
//
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Management;
|
||||||
using System;
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.IO;
|
||||||
|
|
||||||
namespace Microsoft.SqlTools.ServiceLayer.Utility
|
namespace Microsoft.SqlTools.ServiceLayer.Utility
|
||||||
{
|
{
|
||||||
@@ -22,5 +25,45 @@ namespace Microsoft.SqlTools.ServiceLayer.Utility
|
|||||||
string.Compare(databaseName, CommonConstants.ModelDatabaseName, StringComparison.OrdinalIgnoreCase) == 0 ||
|
string.Compare(databaseName, CommonConstants.ModelDatabaseName, StringComparison.OrdinalIgnoreCase) == 0 ||
|
||||||
string.Compare(databaseName, CommonConstants.TempDbDatabaseName, StringComparison.OrdinalIgnoreCase) == 0);
|
string.Compare(databaseName, CommonConstants.TempDbDatabaseName, StringComparison.OrdinalIgnoreCase) == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static string AddStringParameterForInsert(string paramValue)
|
||||||
|
{
|
||||||
|
string value = string.IsNullOrWhiteSpace(paramValue) ? paramValue : CUtils.EscapeStringSQuote(paramValue);
|
||||||
|
return $"'{value}'";
|
||||||
|
}
|
||||||
|
|
||||||
|
public static string AddStringParameterForUpdate(string columnName, string paramValue)
|
||||||
|
{
|
||||||
|
string value = string.IsNullOrWhiteSpace(paramValue) ? paramValue : CUtils.EscapeStringSQuote(paramValue);
|
||||||
|
return $"{columnName} = N'{value}'";
|
||||||
|
}
|
||||||
|
|
||||||
|
public static string AddByteArrayParameterForUpdate(string columnName, string paramName, string fileName, Dictionary<string, object> parameters)
|
||||||
|
{
|
||||||
|
byte[] contentBytes;
|
||||||
|
using (var stream = new FileStream(fileName, FileMode.Open, FileAccess.Read))
|
||||||
|
{
|
||||||
|
using (var reader = new BinaryReader(stream))
|
||||||
|
{
|
||||||
|
contentBytes = reader.ReadBytes((int)stream.Length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parameters.Add($"{paramName}", contentBytes);
|
||||||
|
return $"{columnName} = @{paramName}";
|
||||||
|
}
|
||||||
|
|
||||||
|
public static string AddByteArrayParameterForInsert(string paramName, string fileName, Dictionary<string, object> parameters)
|
||||||
|
{
|
||||||
|
byte[] contentBytes;
|
||||||
|
using (var stream = new FileStream(fileName, FileMode.Open, FileAccess.Read))
|
||||||
|
{
|
||||||
|
using (var reader = new BinaryReader(stream))
|
||||||
|
{
|
||||||
|
contentBytes = reader.ReadBytes((int)stream.Length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parameters.Add($"{paramName}", contentBytes);
|
||||||
|
return $"@{paramName}";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,326 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.SqlTools.Extensibility;
|
||||||
|
using Microsoft.SqlTools.Hosting.Protocol;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection.Contracts;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.ModelManagement;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Test.Common;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.UnitTests;
|
||||||
|
using Moq;
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Data;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using NUnit.Framework;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ModelManagement
|
||||||
|
{
|
||||||
|
public class ModelManagementServiceTests : ServiceTestBase
|
||||||
|
{
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyDeleteModelRequest()
|
||||||
|
{
|
||||||
|
DeleteModelRequestParams requestParams = new DeleteModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name",
|
||||||
|
ModelId = 1
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.DeleteModel(It.IsAny<IDbConnection>(), requestParams));
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyRequst<DeleteModelResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleDeleteModelRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
verify: (actual =>
|
||||||
|
{
|
||||||
|
Assert.NotNull(actual);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyImportModelRequest()
|
||||||
|
{
|
||||||
|
ImportModelRequestParams requestParams = new ImportModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name",
|
||||||
|
Model = new ModelMetadata()
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.ImportModel(It.IsAny<IDbConnection>(), requestParams));
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyRequst<ImportModelResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleModelImportRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
verify: (actual =>
|
||||||
|
{
|
||||||
|
Assert.NotNull(actual);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyUpdateModelRequest()
|
||||||
|
{
|
||||||
|
UpdateModelRequestParams requestParams = new UpdateModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name",
|
||||||
|
Model = new ModelMetadata()
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.UpdateModel(It.IsAny<IDbConnection>(), requestParams));
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyRequst<UpdateModelResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleUpdateModelRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
verify: (actual =>
|
||||||
|
{
|
||||||
|
Assert.NotNull(actual);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyDownloadModelRequest()
|
||||||
|
{
|
||||||
|
DownloadModelRequestParams requestParams = new DownloadModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name",
|
||||||
|
ModelId = 1
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.DownloadModel(It.IsAny<IDbConnection>(), requestParams)).Returns(() => "file path");
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyRequst<DownloadModelResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleDownloadModelRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
verify: (actual =>
|
||||||
|
{
|
||||||
|
Assert.NotNull(actual);
|
||||||
|
Assert.AreEqual(actual.FilePath, "file path");
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyModelTableRequest()
|
||||||
|
{
|
||||||
|
VerifyModelTableRequestParams requestParams = new VerifyModelTableRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name"
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.VerifyImportTable(It.IsAny<IDbConnection>(), requestParams)).Returns(() => true);
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyRequst<VerifyModelTableResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleVerifyModelTableRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
verify: (actual =>
|
||||||
|
{
|
||||||
|
Assert.NotNull(actual);
|
||||||
|
Assert.AreEqual(actual.Verified, true);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyConfigureModelTableRequest()
|
||||||
|
{
|
||||||
|
ConfigureModelTableRequestParams requestParams = new ConfigureModelTableRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name"
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.ConfigureImportTable(It.IsAny<IDbConnection>(), requestParams));
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyRequst<ConfigureModelTableResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleConfigureModelTableRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
verify: (actual =>
|
||||||
|
{
|
||||||
|
Assert.NotNull(actual);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyGetModelRequest()
|
||||||
|
{
|
||||||
|
GetModelsRequestParams requestParams = new GetModelsRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name",
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.GetModels(It.IsAny<IDbConnection>(), requestParams)).Returns(() => new List<ModelMetadata> { new ModelMetadata() });
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyRequst<GetModelsResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleGetModelsRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
},
|
||||||
|
verify: (actual =>
|
||||||
|
{
|
||||||
|
Assert.NotNull(actual);
|
||||||
|
Assert.True(actual.Models.Count == 1);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyRequestFailedResponse()
|
||||||
|
{
|
||||||
|
DeleteModelRequestParams requestParams = new DeleteModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name",
|
||||||
|
ModelId = 1
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.DeleteModel(It.IsAny<IDbConnection>(), requestParams)).Throws(new ApplicationException("error"));
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyError<DeleteModelResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = connectionUrl;
|
||||||
|
await service.HandleDeleteModelRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public async Task VerifyInvalidConnectionResponse()
|
||||||
|
{
|
||||||
|
DeleteModelRequestParams requestParams = new DeleteModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = "db name",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = "table name",
|
||||||
|
ModelId = 1
|
||||||
|
};
|
||||||
|
Mock<ModelOperations> operations = new Mock<ModelOperations>();
|
||||||
|
operations.Setup(x => x.DeleteModel(It.IsAny<IDbConnection>(), requestParams)).Throws(new ApplicationException("error"));
|
||||||
|
ModelManagementService service = new ModelManagementService()
|
||||||
|
{
|
||||||
|
ModelOperations = operations.Object
|
||||||
|
};
|
||||||
|
|
||||||
|
await VerifyError<DeleteModelResponseParams>(
|
||||||
|
test: async (requestContext, connectionUrl) =>
|
||||||
|
{
|
||||||
|
requestParams.OwnerUri = "Invalid connection uri";
|
||||||
|
await service.HandleDeleteModelRequest(requestParams, requestContext);
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task VerifyRequst<T>(Func<RequestContext<T>, string, Task<T>> test, Action<T> verify)
|
||||||
|
{
|
||||||
|
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
|
||||||
|
{
|
||||||
|
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
|
||||||
|
await RunAndVerify<T>(
|
||||||
|
test: (requestContext) => test(requestContext, queryTempFile.FilePath),
|
||||||
|
verify: verify);
|
||||||
|
|
||||||
|
ModelManagementService.Instance.ConnectionServiceInstance.Disconnect(new DisconnectParams
|
||||||
|
{
|
||||||
|
OwnerUri = queryTempFile.FilePath,
|
||||||
|
Type = ServiceLayer.Connection.ConnectionType.Default
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task VerifyError<T>(Func<RequestContext<T>, string, Task<T>> test)
|
||||||
|
{
|
||||||
|
using (SelfCleaningTempFile queryTempFile = new SelfCleaningTempFile())
|
||||||
|
{
|
||||||
|
var connectionResult = await LiveConnectionHelper.InitLiveConnectionInfoAsync("master", queryTempFile.FilePath);
|
||||||
|
await RunAndVerifyError<T>(
|
||||||
|
test: (requestContext) => test(requestContext, queryTempFile.FilePath));
|
||||||
|
|
||||||
|
ModelManagementService.Instance.ConnectionServiceInstance.Disconnect(new DisconnectParams
|
||||||
|
{
|
||||||
|
OwnerUri = queryTempFile.FilePath,
|
||||||
|
Type = ServiceLayer.Connection.ConnectionType.Default
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override RegisteredServiceProvider CreateServiceProviderWithMinServices()
|
||||||
|
{
|
||||||
|
return base.CreateProvider();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,350 @@
|
|||||||
|
//
|
||||||
|
// Copyright (c) Microsoft. All rights reserved.
|
||||||
|
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
|
||||||
|
//
|
||||||
|
|
||||||
|
using Microsoft.Data.SqlClient;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Connection;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.IntegrationTests.Utility;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.ModelManagement;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.ModelManagement.Contracts;
|
||||||
|
using Microsoft.SqlTools.ServiceLayer.Test.Common;
|
||||||
|
using NUnit.Framework;
|
||||||
|
using System;
|
||||||
|
using System.Data;
|
||||||
|
using System.IO;
|
||||||
|
using System.Linq;
|
||||||
|
|
||||||
|
namespace Microsoft.SqlTools.ServiceLayer.IntegrationTests.ModelManagement
|
||||||
|
{
|
||||||
|
public class ModelOperationsTests
|
||||||
|
{
|
||||||
|
[Test]
|
||||||
|
public void VerifyImportTableShouldReturnTrueGivenNoTable()
|
||||||
|
{
|
||||||
|
bool expected = true;
|
||||||
|
bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ModelRequestBase request = new ModelRequestBase
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName
|
||||||
|
};
|
||||||
|
return modelOperations.VerifyImportTable(dbConnection, request);
|
||||||
|
});
|
||||||
|
|
||||||
|
Assert.AreEqual(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void ImportModelShouldImportSuccessfullyGivenValidModel()
|
||||||
|
{
|
||||||
|
string modelFilePath = Path.GetTempFileName();
|
||||||
|
File.WriteAllText(modelFilePath, "Test model");
|
||||||
|
ModelMetadata expected = new ModelMetadata
|
||||||
|
{
|
||||||
|
FilePath = modelFilePath,
|
||||||
|
Created = DateTime.Now.ToString(),
|
||||||
|
Description = "model description",
|
||||||
|
ModelName = "model name",
|
||||||
|
RunId = "run id",
|
||||||
|
Version = "1.2",
|
||||||
|
Framework = "ONNX",
|
||||||
|
FrameworkVersion = "1.3",
|
||||||
|
};
|
||||||
|
ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
|
||||||
|
ImportModelRequestParams request = new ImportModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName,
|
||||||
|
Model = expected
|
||||||
|
};
|
||||||
|
modelOperations.ConfigureImportTable(dbConnection, request);
|
||||||
|
modelOperations.ImportModel(dbConnection, request);
|
||||||
|
var models = modelOperations.GetModels(dbConnection, request);
|
||||||
|
return models.FirstOrDefault(x => x.ModelName == expected.ModelName);
|
||||||
|
});
|
||||||
|
|
||||||
|
Assert.IsNotNull(actual);
|
||||||
|
Assert.AreEqual(expected.RunId, actual.RunId);
|
||||||
|
Assert.AreEqual(expected.Description, actual.Description);
|
||||||
|
Assert.AreEqual(expected.Framework, actual.Framework);
|
||||||
|
Assert.AreEqual(expected.Version, actual.Version);
|
||||||
|
Assert.AreEqual(expected.FrameworkVersion, actual.FrameworkVersion);
|
||||||
|
Assert.IsFalse(string.IsNullOrWhiteSpace(actual.DeploymentTime));
|
||||||
|
Assert.IsFalse(string.IsNullOrWhiteSpace(actual.DeployedBy));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void DownloadModelShouldDownloadSuccessfullyGivenValidModel()
|
||||||
|
{
|
||||||
|
string expected = "Test model";
|
||||||
|
|
||||||
|
string modelFilePath = Path.GetTempFileName();
|
||||||
|
File.WriteAllText(modelFilePath, expected);
|
||||||
|
ModelMetadata model = new ModelMetadata
|
||||||
|
{
|
||||||
|
FilePath = modelFilePath,
|
||||||
|
Created = DateTime.Now.ToString(),
|
||||||
|
Description = "model description",
|
||||||
|
ModelName = "model name",
|
||||||
|
RunId = "run id",
|
||||||
|
Version = "1.2",
|
||||||
|
Framework = "ONNX",
|
||||||
|
FrameworkVersion = "1.3",
|
||||||
|
};
|
||||||
|
string actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ImportModelRequestParams request = new ImportModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName,
|
||||||
|
Model = model
|
||||||
|
};
|
||||||
|
modelOperations.ConfigureImportTable(dbConnection, request);
|
||||||
|
modelOperations.ImportModel(dbConnection, request);
|
||||||
|
var models = modelOperations.GetModels(dbConnection, request);
|
||||||
|
var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
|
||||||
|
Assert.IsNotNull(importedModel);
|
||||||
|
DownloadModelRequestParams downloadRequest = new DownloadModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName,
|
||||||
|
ModelId = importedModel.Id
|
||||||
|
};
|
||||||
|
string downloadedFile = modelOperations.DownloadModel(dbConnection, downloadRequest);
|
||||||
|
return File.ReadAllText(downloadedFile);
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
Assert.IsNotNull(actual);
|
||||||
|
Assert.AreEqual(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void UpdateModelShouldUpdateSuccessfullyGivenValidModel()
|
||||||
|
{
|
||||||
|
string modelFilePath = Path.GetTempFileName();
|
||||||
|
File.WriteAllText(modelFilePath, "Test");
|
||||||
|
ModelMetadata model = new ModelMetadata
|
||||||
|
{
|
||||||
|
FilePath = modelFilePath,
|
||||||
|
Created = DateTime.Now.ToString(),
|
||||||
|
Description = "model description",
|
||||||
|
ModelName = "model name",
|
||||||
|
RunId = "run id",
|
||||||
|
Version = "1.2",
|
||||||
|
Framework = "ONNX",
|
||||||
|
FrameworkVersion = "1.3",
|
||||||
|
};
|
||||||
|
ModelMetadata expected = model;
|
||||||
|
|
||||||
|
ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ImportModelRequestParams request = new ImportModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName,
|
||||||
|
Model = model
|
||||||
|
};
|
||||||
|
modelOperations.ConfigureImportTable(dbConnection, request);
|
||||||
|
modelOperations.ImportModel(dbConnection, request);
|
||||||
|
var models = modelOperations.GetModels(dbConnection, request);
|
||||||
|
var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
|
||||||
|
Assert.IsNotNull(importedModel);
|
||||||
|
|
||||||
|
UpdateModelRequestParams updateRequest = new UpdateModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName,
|
||||||
|
Model = new ModelMetadata
|
||||||
|
{
|
||||||
|
Description = request.Model.Description + "updated",
|
||||||
|
ModelName = request.Model.ModelName + "updated",
|
||||||
|
Version = request.Model.Version + "updated",
|
||||||
|
Framework = request.Model.Framework + "updated",
|
||||||
|
RunId = request.Model.RunId + "updated",
|
||||||
|
FrameworkVersion = request.Model.FrameworkVersion + "updated",
|
||||||
|
Id = importedModel.Id
|
||||||
|
}
|
||||||
|
};
|
||||||
|
modelOperations.UpdateModel(dbConnection, updateRequest);
|
||||||
|
models = modelOperations.GetModels(dbConnection, request);
|
||||||
|
var updatedModel = models.FirstOrDefault(x => x.ModelName == updateRequest.Model.ModelName);
|
||||||
|
Assert.IsNotNull(updatedModel);
|
||||||
|
return updatedModel;
|
||||||
|
});
|
||||||
|
|
||||||
|
Assert.IsNotNull(actual);
|
||||||
|
Assert.AreEqual(expected.Description + "updated", actual.Description);
|
||||||
|
Assert.AreEqual(expected.ModelName + "updated", actual.ModelName);
|
||||||
|
Assert.AreEqual(expected.Version + "updated", actual.Version);
|
||||||
|
Assert.AreEqual(expected.Framework + "updated", actual.Framework);
|
||||||
|
Assert.AreEqual(expected.FrameworkVersion + "updated", actual.FrameworkVersion);
|
||||||
|
Assert.AreEqual(expected.RunId + "updated", actual.RunId);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void DeleteModelShouldDeleteSuccessfullyGivenValidModel()
|
||||||
|
{
|
||||||
|
string modelFilePath = Path.GetTempFileName();
|
||||||
|
File.WriteAllText(modelFilePath, "Test");
|
||||||
|
ModelMetadata model = new ModelMetadata
|
||||||
|
{
|
||||||
|
FilePath = modelFilePath,
|
||||||
|
Created = DateTime.Now.ToString(),
|
||||||
|
Description = "model description",
|
||||||
|
ModelName = "model name",
|
||||||
|
RunId = "run id",
|
||||||
|
Version = "1.2",
|
||||||
|
Framework = "ONNX",
|
||||||
|
FrameworkVersion = "1.3",
|
||||||
|
};
|
||||||
|
ModelMetadata expected = model;
|
||||||
|
|
||||||
|
ModelMetadata actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ImportModelRequestParams request = new ImportModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName,
|
||||||
|
Model = model
|
||||||
|
};
|
||||||
|
modelOperations.ConfigureImportTable(dbConnection, request);
|
||||||
|
modelOperations.ImportModel(dbConnection, request);
|
||||||
|
var models = modelOperations.GetModels(dbConnection, request);
|
||||||
|
var importedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
|
||||||
|
Assert.IsNotNull(importedModel);
|
||||||
|
|
||||||
|
DeleteModelRequestParams deleteRequest = new DeleteModelRequestParams
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName,
|
||||||
|
ModelId = importedModel.Id
|
||||||
|
};
|
||||||
|
modelOperations.DeleteModel(dbConnection, deleteRequest);
|
||||||
|
models = modelOperations.GetModels(dbConnection, request);
|
||||||
|
var updatedModel = models.FirstOrDefault(x => x.ModelName == model.ModelName);
|
||||||
|
return updatedModel;
|
||||||
|
});
|
||||||
|
|
||||||
|
Assert.IsNull(actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void VerifyImportTableShouldReturnTrueGivenTableCreatdByTheService()
|
||||||
|
{
|
||||||
|
bool expected = true;
|
||||||
|
bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ModelRequestBase request = new ModelRequestBase
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName
|
||||||
|
};
|
||||||
|
modelOperations.ConfigureImportTable(dbConnection, request);
|
||||||
|
return modelOperations.VerifyImportTable(dbConnection, request);
|
||||||
|
});
|
||||||
|
|
||||||
|
Assert.AreEqual(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void VerifyImportTableShouldReturnTrueGivenTableNameThatNeedsEscaping()
|
||||||
|
{
|
||||||
|
bool expected = true;
|
||||||
|
bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ModelRequestBase request = new ModelRequestBase
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName
|
||||||
|
};
|
||||||
|
modelOperations.ConfigureImportTable(dbConnection, request);
|
||||||
|
return modelOperations.VerifyImportTable(dbConnection, request);
|
||||||
|
}, null, "models[]'");
|
||||||
|
|
||||||
|
Assert.AreEqual(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void VerifyImportTableShouldReturnFalseGivenInvalidDbName()
|
||||||
|
{
|
||||||
|
Assert.Throws(typeof(SqlException), () => VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ModelRequestBase request = new ModelRequestBase
|
||||||
|
{
|
||||||
|
DatabaseName = "invalidDb",
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName
|
||||||
|
};
|
||||||
|
modelOperations.ConfigureImportTable(dbConnection, request);
|
||||||
|
return modelOperations.VerifyImportTable(dbConnection, request);
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Test]
|
||||||
|
public void VerifyImportTableShouldReturnFalseGivenInvalidTable()
|
||||||
|
{
|
||||||
|
bool expected = false;
|
||||||
|
bool actual = VerifyModelOperation((dbConnection, databaseName , tableName) =>
|
||||||
|
{
|
||||||
|
dbConnection.ChangeDatabase(databaseName);
|
||||||
|
using (IDbCommand command = dbConnection.CreateCommand())
|
||||||
|
{
|
||||||
|
command.CommandText = $"Create Table {tableName} (Id int, name varchar(10))";
|
||||||
|
command.ExecuteNonQuery();
|
||||||
|
}
|
||||||
|
dbConnection.ChangeDatabase("master");
|
||||||
|
|
||||||
|
ModelOperations modelOperations = new ModelOperations();
|
||||||
|
ModelRequestBase request = new ModelRequestBase
|
||||||
|
{
|
||||||
|
DatabaseName = databaseName,
|
||||||
|
SchemaName = "dbo",
|
||||||
|
TableName = tableName
|
||||||
|
};
|
||||||
|
return modelOperations.VerifyImportTable(dbConnection, request);
|
||||||
|
});
|
||||||
|
|
||||||
|
Assert.AreEqual(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
private T VerifyModelOperation<T>(Func<IDbConnection, string, string, T> operation, string dbName = null, string tbName = null)
|
||||||
|
{
|
||||||
|
string databaseName = dbName ?? "testModels_" + new Random().Next(10000000, 99999999);
|
||||||
|
string tableName = tbName ?? "models";
|
||||||
|
using (SqlTestDb testDb = SqlTestDb.CreateNew(TestServerType.OnPrem, false, databaseName))
|
||||||
|
{
|
||||||
|
var liveConnection = LiveConnectionHelper.InitLiveConnectionInfo(databaseName);
|
||||||
|
ConnectionInfo connInfo = liveConnection.ConnectionInfo;
|
||||||
|
IDbConnection dbConnection = ConnectionService.OpenSqlConnection(connInfo);
|
||||||
|
dbConnection.ChangeDatabase("master");
|
||||||
|
|
||||||
|
T result = operation(dbConnection, databaseName, tableName);
|
||||||
|
Assert.AreEqual(dbConnection.Database, "master");
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user