mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-13 19:48:37 -05:00
Machine Learning Extension - Model details (#9377)
* Machine Learning Services Extension - adding model details
This commit is contained in:
@@ -21,6 +21,7 @@ import { HttpClient } from '../common/httpClient';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as utils from '../common/utils';
|
||||
|
||||
/**
|
||||
* Azure Model Service
|
||||
@@ -109,7 +110,7 @@ export class AzureModelRegistryService {
|
||||
try {
|
||||
const downloadUrls = await this.getAssetArtifactsDownloadLinks(account, subscription, resourceGroup, workspace, model, tenant);
|
||||
if (downloadUrls && downloadUrls.length > 0) {
|
||||
downloadedFilePath = await this.downloadArtifact(downloadUrls[0]);
|
||||
downloadedFilePath = await this.execDownloadArtifactTask(downloadUrls[0]);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
@@ -122,29 +123,15 @@ export class AzureModelRegistryService {
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
*/
|
||||
public async downloadArtifact(downloadUrl: string): Promise<string> {
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
let msgTaskName = constants.downloadModelMsgTaskName;
|
||||
this._apiWrapper.startBackgroundOperation({
|
||||
displayName: msgTaskName,
|
||||
description: msgTaskName,
|
||||
isCancelable: false,
|
||||
operation: async op => {
|
||||
let tempFilePath: string = '';
|
||||
try {
|
||||
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await this._httpClient.download(downloadUrl, tempFilePath, op, this._outputChannel);
|
||||
public async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
|
||||
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
|
||||
return results && results.length > 0 ? results[0] : constants.noResultError;
|
||||
}
|
||||
|
||||
op.updateStatus(azdata.TaskStatus.Succeeded);
|
||||
resolve(tempFilePath);
|
||||
} catch (error) {
|
||||
let errorMsg = constants.installDependenciesError(error ? error.message : '');
|
||||
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
|
||||
reject(errorMsg);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
private async downloadArtifact(downloadUrl: string): Promise<string> {
|
||||
let tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await this._httpClient.download(downloadUrl, tempFilePath, this._outputChannel);
|
||||
return tempFilePath;
|
||||
}
|
||||
|
||||
private async fetchWorkspaces(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
|
||||
|
||||
@@ -49,8 +49,12 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
||||
* An interface representing registered model
|
||||
*/
|
||||
export interface RegisteredModel {
|
||||
id: number,
|
||||
name: string
|
||||
id?: number,
|
||||
artifactName?: string,
|
||||
title?: string,
|
||||
created?: string,
|
||||
version?: string
|
||||
description?: string
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -9,6 +9,9 @@ import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as utils from '../common/utils';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
/**
|
||||
* Service to import model to database
|
||||
@@ -18,13 +21,22 @@ export class ModelImporter {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config) {
|
||||
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
|
||||
}
|
||||
|
||||
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
await this.installDependencies();
|
||||
await this.executeScripts(connection, modelFolderPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for model importer
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
|
||||
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
|
||||
}
|
||||
|
||||
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
|
||||
const parts = modelFolderPath.split('\\');
|
||||
@@ -36,7 +48,7 @@ export class ModelImporter {
|
||||
let server = connection.serverName;
|
||||
|
||||
const experimentId = `ads_ml_experiment_${UUID.generateUuid()}`;
|
||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}` : '';
|
||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
|
||||
let scripts: string[] = [
|
||||
'import mlflow.onnx',
|
||||
'import onnx',
|
||||
@@ -44,7 +56,7 @@ export class ModelImporter {
|
||||
`onx = onnx.load("${modelFolderPath}")`,
|
||||
'client = MlflowClient()',
|
||||
`exp_name = "${experimentId}"`,
|
||||
`db_uri_artifact = "mssql+pyodbc://${credential}@${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server"`,
|
||||
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
|
||||
'client.create_experiment(exp_name, artifact_location=db_uri_artifact)',
|
||||
'mlflow.set_experiment(exp_name)',
|
||||
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'
|
||||
|
||||
@@ -6,10 +6,12 @@
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import { Config } from '../configurations/config';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { RegisteredModel } from './interfaces';
|
||||
import { ModelImporter } from './modelImporter';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
/**
|
||||
* Service to registered models
|
||||
@@ -33,20 +35,57 @@ export class RegisteredModelService {
|
||||
let result = await this.runRegisteredModelsListQuery(connection);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push({
|
||||
id: +row[0].displayValue,
|
||||
name: row[1].displayValue
|
||||
});
|
||||
list.push(this.loadModelData(row));
|
||||
});
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string) {
|
||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||
return {
|
||||
id: +row[0].displayValue,
|
||||
artifactName: row[1].displayValue,
|
||||
title: row[2].displayValue,
|
||||
description: row[3].displayValue,
|
||||
version: row[4].displayValue,
|
||||
created: row[5].displayValue
|
||||
};
|
||||
}
|
||||
|
||||
public async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
let result = await this.runUpdateModelQuery(connection, model);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
}
|
||||
}
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let currentModels = await this.getRegisteredModels();
|
||||
await this._modelImporter.registerModel(connection, filePath);
|
||||
let updatedModels = await this.getRegisteredModels();
|
||||
if (details && updatedModels.length >= currentModels.length + 1) {
|
||||
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
|
||||
const addedModel = updatedModels[updatedModels.length - 1];
|
||||
addedModel.title = details.title;
|
||||
addedModel.description = details.description;
|
||||
addedModel.version = details.version;
|
||||
const updatedModel = await this.updateModel(addedModel);
|
||||
if (!updatedModel) {
|
||||
throw Error(constants.updateModelFailedError);
|
||||
}
|
||||
|
||||
} else {
|
||||
throw Error(constants.importModelFailedError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,22 +95,91 @@ export class RegisteredModelService {
|
||||
|
||||
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
|
||||
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private registeredModelsQuery(databaseName: string, tableName: string) {
|
||||
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
|
||||
return `
|
||||
IF (EXISTS (SELECT name
|
||||
FROM master.dbo.sysdatabases
|
||||
WHERE ('[' + name + ']' = '${databaseName}'
|
||||
OR name = '${databaseName}')))
|
||||
${this.configureTable(databaseName, tableName)}
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM [${escapedDbName}].dbo.[${escapedTableName}]
|
||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||
Order by artifact_id
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the table and adds extra columns (name, description, version) if doesn't already exist.
|
||||
* Note: this code is temporary and will be removed weh the table supports the required schema
|
||||
* @param databaseName
|
||||
* @param tableName
|
||||
*/
|
||||
private configureTable(databaseName: string, tableName: string): string {
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
IF EXISTS
|
||||
( SELECT [name]
|
||||
FROM sys.tables
|
||||
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
)
|
||||
BEGIN
|
||||
SELECT artifact_id, artifact_name, group_path, artifact_initial_size from ${databaseName}.${tableName}
|
||||
WHERE artifact_name like '%.onnx'
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
|
||||
BEGIN
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
END
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
|
||||
END
|
||||
`;
|
||||
}
|
||||
|
||||
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
|
||||
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
UPDATE ${escapedTableName}
|
||||
SET
|
||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||
WHERE artifact_id = ${model.id};
|
||||
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user