mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-01-25 09:35:37 -05:00
Machine Learning - Supporting multiple model import (#9869)
* Machine Learning Extension - Changed the deploy wizard to deploy multiple files
This commit is contained in:
@@ -86,22 +86,23 @@ export class DeployedModelService {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let currentModels = await this.getDeployedModels();
|
||||
await this._modelClient.deployModel(connection, filePath);
|
||||
let updatedModels = await this.getDeployedModels();
|
||||
if (details && updatedModels.length >= currentModels.length + 1) {
|
||||
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
|
||||
const addedModel = updatedModels[updatedModels.length - 1];
|
||||
addedModel.title = details.title;
|
||||
addedModel.description = details.description;
|
||||
addedModel.version = details.version;
|
||||
const updatedModel = await this.updateModel(addedModel);
|
||||
if (!updatedModel) {
|
||||
throw Error(constants.updateModelFailedError);
|
||||
}
|
||||
const content = await utils.readFileInHex(filePath);
|
||||
const fileName = details?.fileName || utils.getFileName(filePath);
|
||||
let modelToAdd: RegisteredModel = {
|
||||
id: 0,
|
||||
artifactName: fileName,
|
||||
content: content,
|
||||
title: details?.title || fileName,
|
||||
description: details?.description,
|
||||
version: details?.version
|
||||
};
|
||||
await this._queryRunner.safeRunQuery(connection, this.getInsertModelQuery(connection.databaseName, modelToAdd));
|
||||
|
||||
} else {
|
||||
throw Error(constants.importModelFailedError);
|
||||
let updatedModels = await this.getDeployedModels();
|
||||
if (updatedModels.length < currentModels.length + 1) {
|
||||
throw Error(constants.importModelFailedError(details?.title, filePath));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||
@@ -115,20 +116,6 @@ export class DeployedModelService {
|
||||
};
|
||||
}
|
||||
|
||||
private async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
const query = this.getUpdateModelQuery(connection.databaseName, model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result?.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
}
|
||||
}
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
@@ -190,7 +177,7 @@ export class DeployedModelService {
|
||||
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
|
||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[artifact_name] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NULL,
|
||||
[artifact_content] [varbinary](max) NOT NULL,
|
||||
[artifact_initial_size] [bigint] NULL,
|
||||
[name] [varchar](256) NULL,
|
||||
@@ -207,20 +194,24 @@ export class DeployedModelService {
|
||||
`;
|
||||
}
|
||||
|
||||
public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
public getInsertModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
let updateScript = `
|
||||
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||
SET
|
||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||
WHERE artifact_id = ${model.id}`;
|
||||
Insert into ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||
(artifact_name, group_path, artifact_content, name, version, description)
|
||||
values (
|
||||
'${utils.doubleEscapeSingleQuotes(model.artifactName || '')}',
|
||||
'ADS',
|
||||
${utils.doubleEscapeSingleQuotes(model.content || '')},
|
||||
'${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
'${utils.doubleEscapeSingleQuotes(model.description || '')}')
|
||||
`;
|
||||
|
||||
return `
|
||||
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_id = ${model.id};
|
||||
WHERE artifact_id = SCOPE_IDENTITY();
|
||||
`;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user