Machine Learning Services Extension - Predict wizard (#9450)

*MLS extension - Added predict wizard
This commit is contained in:
Leila Lali
2020-03-09 15:40:05 -07:00
committed by GitHub
parent b017634431
commit 3be3563b0d
37 changed files with 1501 additions and 219 deletions

View File

@@ -48,13 +48,19 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
/**
* An interface representing registered model
*/
export interface RegisteredModel {
id?: number,
artifactName?: string,
title?: string,
created?: string,
version?: string
description?: string
export interface RegisteredModel extends RegisteredModelDetails {
id: number;
artifactName: string;
}
/**
* An interface representing registered model
*/
export interface RegisteredModelDetails {
title: string;
created?: string;
version?: string;
description?: string;
}
/**

View File

@@ -12,6 +12,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as utils from '../common/utils';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
import * as os from 'os';
/**
* Service to import model to database
@@ -39,8 +40,8 @@ export class ModelImporter {
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
const parts = modelFolderPath.split('\\');
modelFolderPath = parts.join('/');
let home = utils.makeLinuxPath(os.homedir());
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
@@ -51,9 +52,12 @@ export class ModelImporter {
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
let scripts: string[] = [
'import mlflow.onnx',
`tracking_uri = "file://${home}/mlruns"`,
'print(tracking_uri)',
'import onnx',
'from mlflow.tracking.client import MlflowClient',
`onx = onnx.load("${modelFolderPath}")`,
`mlflow.set_tracking_uri(tracking_uri)`,
'client = MlflowClient()',
`exp_name = "${experimentId}"`,
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,

View File

@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
import { RegisteredModel } from './interfaces';
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
import { ModelImporter } from './modelImporter';
import * as constants from '../common/constants';
@@ -32,7 +32,10 @@ export class RegisteredModelService {
let connection = await this.getCurrentConnection();
let list: RegisteredModel[] = [];
if (connection) {
let result = await this.runRegisteredModelsListQuery(connection);
let query = this.getConfigureQuery(connection.databaseName);
await this._queryRunner.safeRunQuery(connection, query);
query = this.registeredModelsQuery();
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push(this.loadModelData(row));
@@ -57,7 +60,8 @@ export class RegisteredModelService {
let connection = await this.getCurrentConnection();
let updatedModel: RegisteredModel | undefined = undefined;
if (connection) {
let result = await this.runUpdateModelQuery(connection, model);
const query = this.getUpdateModelScript(connection.databaseName, model);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
const row = result.rows[0];
updatedModel = this.loadModelData(row);
@@ -66,7 +70,7 @@ export class RegisteredModelService {
return updatedModel;
}
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
let connection = await this.getCurrentConnection();
if (connection) {
let currentModels = await this.getRegisteredModels();
@@ -93,35 +97,14 @@ export class RegisteredModelService {
return await this._apiWrapper.getCurrentConnection();
}
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
} catch {
return undefined;
}
private getConfigureQuery(currentDatabaseName: string): string {
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
}
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
} catch {
return undefined;
}
}
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
private registeredModelsQuery(): string {
return `
${this.configureTable(databaseName, tableName)}
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created
FROM [${escapedDbName}].dbo.[${escapedTableName}]
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
@@ -133,52 +116,74 @@ export class RegisteredModelService {
* @param databaseName
* @param tableName
*/
private configureTable(databaseName: string, tableName: string): string {
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
private configureTable(): string {
let databaseName = this._config.registeredModelDatabaseName;
let tableName = this._config.registeredModelTableName;
let schemaName = this._config.registeredModelTableSchemaName;
return `
USE [${escapedDbName}]
IF NOT EXISTS (
SELECT [name]
FROM sys.databases
WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}'
)
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}]
GO
USE [${utils.doubleEscapeSingleBrackets(databaseName)}]
IF EXISTS
( SELECT [name]
FROM sys.tables
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
( SELECT [t.name], [s.name]
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name')
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version')
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created')
BEGIN
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
END
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description')
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL
END
Else
BEGIN
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
[artifact_name] [varchar](256) NOT NULL,
[group_path] [varchar](256) NOT NULL,
[artifact_content] [varbinary](max) NOT NULL,
[artifact_initial_size] [bigint] NULL,
[name] [varchar](256) NULL,
[version] [varchar](256) NULL,
[created] [datetime] NULL,
[description] [varchar](256) NULL,
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
(
[artifact_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
END
`;
}
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
let updateScript = `
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
SET
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
WHERE artifact_id = ${model.id}`;
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
return `
USE [${escapedDbName}]
UPDATE ${escapedTableName}
SET
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
WHERE artifact_id = ${model.id};
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
SELECT artifact_id, artifact_name, name, description, version, created
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
WHERE artifact_id = ${model.id};
`;
}