mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
Machine Learning Services Extension - Predict wizard (#9450)
*MLS extension - Added predict wizard
This commit is contained in:
@@ -48,13 +48,19 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
||||
/**
|
||||
* An interface representing registered model
|
||||
*/
|
||||
export interface RegisteredModel {
|
||||
id?: number,
|
||||
artifactName?: string,
|
||||
title?: string,
|
||||
created?: string,
|
||||
version?: string
|
||||
description?: string
|
||||
export interface RegisteredModel extends RegisteredModelDetails {
|
||||
id: number;
|
||||
artifactName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* An interface representing registered model
|
||||
*/
|
||||
export interface RegisteredModelDetails {
|
||||
title: string;
|
||||
created?: string;
|
||||
version?: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -12,6 +12,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as utils from '../common/utils';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
import * as os from 'os';
|
||||
|
||||
/**
|
||||
* Service to import model to database
|
||||
@@ -39,8 +40,8 @@ export class ModelImporter {
|
||||
|
||||
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
|
||||
const parts = modelFolderPath.split('\\');
|
||||
modelFolderPath = parts.join('/');
|
||||
let home = utils.makeLinuxPath(os.homedir());
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
|
||||
@@ -51,9 +52,12 @@ export class ModelImporter {
|
||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
|
||||
let scripts: string[] = [
|
||||
'import mlflow.onnx',
|
||||
`tracking_uri = "file://${home}/mlruns"`,
|
||||
'print(tracking_uri)',
|
||||
'import onnx',
|
||||
'from mlflow.tracking.client import MlflowClient',
|
||||
`onx = onnx.load("${modelFolderPath}")`,
|
||||
`mlflow.set_tracking_uri(tracking_uri)`,
|
||||
'client = MlflowClient()',
|
||||
`exp_name = "${experimentId}"`,
|
||||
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
|
||||
|
||||
@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import { Config } from '../configurations/config';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { RegisteredModel } from './interfaces';
|
||||
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
|
||||
import { ModelImporter } from './modelImporter';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
@@ -32,7 +32,10 @@ export class RegisteredModelService {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: RegisteredModel[] = [];
|
||||
if (connection) {
|
||||
let result = await this.runRegisteredModelsListQuery(connection);
|
||||
let query = this.getConfigureQuery(connection.databaseName);
|
||||
await this._queryRunner.safeRunQuery(connection, query);
|
||||
query = this.registeredModelsQuery();
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push(this.loadModelData(row));
|
||||
@@ -57,7 +60,8 @@ export class RegisteredModelService {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
let result = await this.runUpdateModelQuery(connection, model);
|
||||
const query = this.getUpdateModelScript(connection.databaseName, model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
@@ -66,7 +70,7 @@ export class RegisteredModelService {
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let currentModels = await this.getRegisteredModels();
|
||||
@@ -93,35 +97,14 @@ export class RegisteredModelService {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
private getConfigureQuery(currentDatabaseName: string): string {
|
||||
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
|
||||
}
|
||||
|
||||
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
|
||||
private registeredModelsQuery(): string {
|
||||
return `
|
||||
${this.configureTable(databaseName, tableName)}
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM [${escapedDbName}].dbo.[${escapedTableName}]
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||
Order by artifact_id
|
||||
`;
|
||||
@@ -133,52 +116,74 @@ export class RegisteredModelService {
|
||||
* @param databaseName
|
||||
* @param tableName
|
||||
*/
|
||||
private configureTable(databaseName: string, tableName: string): string {
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
private configureTable(): string {
|
||||
let databaseName = this._config.registeredModelDatabaseName;
|
||||
let tableName = this._config.registeredModelTableName;
|
||||
let schemaName = this._config.registeredModelTableSchemaName;
|
||||
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
IF NOT EXISTS (
|
||||
SELECT [name]
|
||||
FROM sys.databases
|
||||
WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}'
|
||||
)
|
||||
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
||||
GO
|
||||
USE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
||||
IF EXISTS
|
||||
( SELECT [name]
|
||||
FROM sys.tables
|
||||
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
( SELECT [t.name], [s.name]
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
||||
)
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name')
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version')
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created')
|
||||
BEGIN
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
END
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description')
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL
|
||||
END
|
||||
Else
|
||||
BEGIN
|
||||
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
|
||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[artifact_name] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NOT NULL,
|
||||
[artifact_content] [varbinary](max) NOT NULL,
|
||||
[artifact_initial_size] [bigint] NULL,
|
||||
[name] [varchar](256) NULL,
|
||||
[version] [varchar](256) NULL,
|
||||
[created] [datetime] NULL,
|
||||
[description] [varchar](256) NULL,
|
||||
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[artifact_id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||
END
|
||||
`;
|
||||
}
|
||||
|
||||
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
|
||||
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
let updateScript = `
|
||||
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||
SET
|
||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||
WHERE artifact_id = ${model.id}`;
|
||||
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
UPDATE ${escapedTableName}
|
||||
SET
|
||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||
WHERE artifact_id = ${model.id};
|
||||
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
|
||||
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user