Machine Learning Extension - Model details (#9377)

* Machine Learning Services Extension - adding model details
This commit is contained in:
Leila Lali
2020-03-02 12:47:09 -08:00
committed by GitHub
parent c1f6a67829
commit b5b65117a7
30 changed files with 852 additions and 224 deletions

View File

@@ -21,6 +21,7 @@ import { HttpClient } from '../common/httpClient';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as path from 'path';
import * as os from 'os';
import * as utils from '../common/utils';
/**
* Azure Model Service
@@ -109,7 +110,7 @@ export class AzureModelRegistryService {
try {
const downloadUrls = await this.getAssetArtifactsDownloadLinks(account, subscription, resourceGroup, workspace, model, tenant);
if (downloadUrls && downloadUrls.length > 0) {
downloadedFilePath = await this.downloadArtifact(downloadUrls[0]);
downloadedFilePath = await this.execDownloadArtifactTask(downloadUrls[0]);
}
} catch (error) {
@@ -122,29 +123,15 @@ export class AzureModelRegistryService {
/**
* Installs dependencies for the extension
*/
public async downloadArtifact(downloadUrl: string): Promise<string> {
return new Promise<string>((resolve, reject) => {
let msgTaskName = constants.downloadModelMsgTaskName;
this._apiWrapper.startBackgroundOperation({
displayName: msgTaskName,
description: msgTaskName,
isCancelable: false,
operation: async op => {
let tempFilePath: string = '';
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await this._httpClient.download(downloadUrl, tempFilePath, op, this._outputChannel);
public async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
return results && results.length > 0 ? results[0] : constants.noResultError;
}
op.updateStatus(azdata.TaskStatus.Succeeded);
resolve(tempFilePath);
} catch (error) {
let errorMsg = constants.installDependenciesError(error ? error.message : '');
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
reject(errorMsg);
}
}
});
});
private async downloadArtifact(downloadUrl: string): Promise<string> {
let tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await this._httpClient.download(downloadUrl, tempFilePath, this._outputChannel);
return tempFilePath;
}
private async fetchWorkspaces(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {

View File

@@ -49,8 +49,12 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
* An interface representing registered model
*/
export interface RegisteredModel {
id: number,
name: string
id?: number,
artifactName?: string,
title?: string,
created?: string,
version?: string
description?: string
}
/**

View File

@@ -9,6 +9,9 @@ import { ApiWrapper } from '../common/apiWrapper';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as utils from '../common/utils';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
/**
* Service to import model to database
@@ -18,13 +21,22 @@ export class ModelImporter {
/**
*
*/
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config) {
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
}
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
await this.installDependencies();
await this.executeScripts(connection, modelFolderPath);
}
/**
* Installs dependencies for model importer
*/
public async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
}
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
const parts = modelFolderPath.split('\\');
@@ -36,7 +48,7 @@ export class ModelImporter {
let server = connection.serverName;
const experimentId = `ads_ml_experiment_${UUID.generateUuid()}`;
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}` : '';
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
let scripts: string[] = [
'import mlflow.onnx',
'import onnx',
@@ -44,7 +56,7 @@ export class ModelImporter {
`onx = onnx.load("${modelFolderPath}")`,
'client = MlflowClient()',
`exp_name = "${experimentId}"`,
`db_uri_artifact = "mssql+pyodbc://${credential}@${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server"`,
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
'client.create_experiment(exp_name, artifact_location=db_uri_artifact)',
'mlflow.set_experiment(exp_name)',
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'

View File

@@ -6,10 +6,12 @@
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
import { RegisteredModel } from './interfaces';
import { ModelImporter } from './modelImporter';
import * as constants from '../common/constants';
/**
* Service to registered models
@@ -33,20 +35,57 @@ export class RegisteredModelService {
let result = await this.runRegisteredModelsListQuery(connection);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push({
id: +row[0].displayValue,
name: row[1].displayValue
});
list.push(this.loadModelData(row));
});
}
}
return list;
}
public async registerLocalModel(filePath: string) {
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
return {
id: +row[0].displayValue,
artifactName: row[1].displayValue,
title: row[2].displayValue,
description: row[3].displayValue,
version: row[4].displayValue,
created: row[5].displayValue
};
}
public async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
let connection = await this.getCurrentConnection();
let updatedModel: RegisteredModel | undefined = undefined;
if (connection) {
let result = await this.runUpdateModelQuery(connection, model);
if (result && result.rows && result.rows.length > 0) {
const row = result.rows[0];
updatedModel = this.loadModelData(row);
}
}
return updatedModel;
}
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
let connection = await this.getCurrentConnection();
if (connection) {
let currentModels = await this.getRegisteredModels();
await this._modelImporter.registerModel(connection, filePath);
let updatedModels = await this.getRegisteredModels();
if (details && updatedModels.length >= currentModels.length + 1) {
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
const addedModel = updatedModels[updatedModels.length - 1];
addedModel.title = details.title;
addedModel.description = details.description;
addedModel.version = details.version;
const updatedModel = await this.updateModel(addedModel);
if (!updatedModel) {
throw Error(constants.updateModelFailedError);
}
} else {
throw Error(constants.importModelFailedError);
}
}
}
@@ -56,22 +95,91 @@ export class RegisteredModelService {
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
} catch {
return undefined;
}
}
private registeredModelsQuery(databaseName: string, tableName: string) {
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
} catch {
return undefined;
}
}
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
return `
IF (EXISTS (SELECT name
FROM master.dbo.sysdatabases
WHERE ('[' + name + ']' = '${databaseName}'
OR name = '${databaseName}')))
${this.configureTable(databaseName, tableName)}
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created
FROM [${escapedDbName}].dbo.[${escapedTableName}]
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
}
/**
* Update the table and adds extra columns (name, description, version) if doesn't already exist.
* Note: this code is temporary and will be removed weh the table supports the required schema
* @param databaseName
* @param tableName
*/
private configureTable(databaseName: string, tableName: string): string {
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
return `
USE [${escapedDbName}]
IF EXISTS
( SELECT [name]
FROM sys.tables
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
)
BEGIN
SELECT artifact_id, artifact_name, group_path, artifact_initial_size from ${databaseName}.${tableName}
WHERE artifact_name like '%.onnx'
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
BEGIN
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
END
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
END
`;
}
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
return `
USE [${escapedDbName}]
UPDATE ${escapedTableName}
SET
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
WHERE artifact_id = ${model.id};
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
WHERE artifact_id = ${model.id};
`;
}
}