mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
Machine Learning Services Extension - Predict wizard (#9450)
*MLS extension - Added predict wizard
This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
"modelManagement": {
|
||||
"registeredModelsDatabaseName": "MlFlowDB",
|
||||
"registeredModelsTableName": "artifacts",
|
||||
"registeredModelsTableSchemaName": "dbo",
|
||||
"amlModelManagementUrl": "modelmanagement.azureml.net",
|
||||
"amlExperienceUrl": "experiments.azureml.net",
|
||||
"amlApiVersion": "2018-11-19",
|
||||
|
||||
@@ -57,6 +57,10 @@
|
||||
"command": "mls.command.managePackages",
|
||||
"title": "%mls.command.managePackages%"
|
||||
},
|
||||
{
|
||||
"command": "mls.command.predictModel",
|
||||
"title": "%mls.command.predictModel%"
|
||||
},
|
||||
{
|
||||
"command": "mls.command.manageModels",
|
||||
"title": "%mls.command.manageModels%"
|
||||
@@ -110,7 +114,7 @@
|
||||
"mls.command.managePackages",
|
||||
"mls.command.manageLanguages",
|
||||
"mls.command.manageModels",
|
||||
"mls.command.registerModel"
|
||||
"mls.command.predictModel"
|
||||
]
|
||||
}
|
||||
},
|
||||
|
||||
@@ -7,8 +7,9 @@
|
||||
"title.endpoints": "Endpoints",
|
||||
"mls.command.managePackages": "Manage Packages in SQL Server",
|
||||
"mls.command.manageLanguages": "Manage External Languages",
|
||||
"mls.command.manageModels": "Manage Models",
|
||||
"mls.command.registerModel": "Register Model",
|
||||
"mls.command.predictModel": "Make prediction",
|
||||
"mls.command.manageModels": "Manage models",
|
||||
"mls.command.registerModel": "Register model",
|
||||
"mls.command.odbcdriver": "Install ODBC Driver for SQL Server",
|
||||
"mls.command.mlsdocs": "Machine Learning Services Documentation",
|
||||
"mls.configuration.title": "Machine Learning Services configurations",
|
||||
|
||||
@@ -105,4 +105,28 @@ export class ApiWrapper {
|
||||
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
|
||||
return vscode.window.showQuickPick(items, options, token);
|
||||
}
|
||||
|
||||
public listDatabases(connectionId: string): Thenable<string[]> {
|
||||
return azdata.connection.listDatabases(connectionId);
|
||||
}
|
||||
|
||||
public openTextDocument(options?: { language?: string; content?: string; }): Thenable<vscode.TextDocument> {
|
||||
return vscode.workspace.openTextDocument(options);
|
||||
}
|
||||
|
||||
public connect(fileUri: string, connectionId: string): Thenable<void> {
|
||||
return azdata.queryeditor.connect(fileUri, connectionId);
|
||||
}
|
||||
|
||||
public runQuery(fileUri: string, options?: Map<string, string>, runCurrentQuery?: boolean): void {
|
||||
azdata.queryeditor.runQuery(fileUri, options, runCurrentQuery);
|
||||
}
|
||||
|
||||
public showTextDocument(uri: vscode.Uri, options?: vscode.TextDocumentShowOptions): Thenable<vscode.TextEditor> {
|
||||
return vscode.window.showTextDocument(uri, options);
|
||||
}
|
||||
|
||||
public createButton(label: string, position?: azdata.window.DialogButtonPosition): azdata.window.Button {
|
||||
return azdata.window.createButton(label, position);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups';
|
||||
// Tasks, commands
|
||||
//
|
||||
export const mlManageLanguagesCommand = 'mls.command.manageLanguages';
|
||||
export const mlsPredictModelCommand = 'mls.command.predictModel';
|
||||
export const mlManageModelsCommand = 'mls.command.manageModels';
|
||||
export const mlRegisterModelCommand = 'mls.command.registerModel';
|
||||
export const mlManagePackagesCommand = 'mls.command.managePackages';
|
||||
@@ -116,6 +117,12 @@ export const modelCreated = localize('models.created', "Date Created");
|
||||
export const modelVersion = localize('models.version', "Version");
|
||||
export const browseModels = localize('models.browseButton', "...");
|
||||
export const azureAccount = localize('models.azureAccount', "Azure account");
|
||||
export const columnDatabase = localize('predict.columnDatabase', "Database");
|
||||
export const columnTable = localize('predict.columnTable', "Table");
|
||||
export const inputColumns = localize('predict.inputColumns', "Input columns");
|
||||
export const outputColumns = localize('predict.outputColumns', "Output column");
|
||||
export const columnName = localize('predict.columnName', "Name");
|
||||
export const inputName = localize('predict.inputName', "Input Name");
|
||||
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
|
||||
export const azureGroup = localize('models.azureGroup', "Azure resource group");
|
||||
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
|
||||
@@ -125,18 +132,25 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model
|
||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Ender model source details");
|
||||
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Select input columns");
|
||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
|
||||
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||
export const azureRegisterModel = localize('models.azureRegisterModel', "Register");
|
||||
export const predictModel = localize('models.predictModel', "Predict");
|
||||
export const registerModelTitle = localize('models.RegisterWizard', "Register model");
|
||||
export const makePredictionTitle = localize('models.makePredictionTitle', "Make prediction");
|
||||
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
|
||||
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
|
||||
export const localModelSource = localize('models.localModelSource', "Upload file");
|
||||
export const azureModelSource = localize('models.azureModelSource', "Import from AzureML registry");
|
||||
export const registeredModelsSource = localize('models.registeredModelsSource', "Select managed models");
|
||||
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
|
||||
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
|
||||
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
||||
export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict");
|
||||
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
||||
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");
|
||||
|
||||
|
||||
@@ -163,4 +163,21 @@ export class QueryRunner {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes the query but doesn't fail it is fails
|
||||
* @param connection SQL connection
|
||||
* @param query query to run
|
||||
*/
|
||||
public async safeRunQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this.runQuery(connection, query);
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import * as fs from 'fs';
|
||||
import * as constants from '../common/constants';
|
||||
import { promisify } from 'util';
|
||||
import { ApiWrapper } from './apiWrapper';
|
||||
import { Config } from '../configurations/config';
|
||||
|
||||
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
|
||||
let tempFilePath: string = '';
|
||||
@@ -25,6 +26,11 @@ export async function execCommandOnTempFile<T>(content: string, command: (filePa
|
||||
}
|
||||
}
|
||||
|
||||
export async function readFileInHex(filePath: string): Promise<string> {
|
||||
let buffer = await fs.promises.readFile(filePath);
|
||||
return `0X${buffer.toString('hex')}`;
|
||||
}
|
||||
|
||||
export async function exists(path: string): Promise<boolean> {
|
||||
return promisify(fs.exists)(path);
|
||||
}
|
||||
@@ -109,8 +115,8 @@ export function isWindows(): boolean {
|
||||
* ' => ''
|
||||
* @param value The string to escape
|
||||
*/
|
||||
export function doubleEscapeSingleQuotes(value: string): string {
|
||||
return value.replace(/'/g, '\'\'');
|
||||
export function doubleEscapeSingleQuotes(value: string | undefined): string {
|
||||
return value ? value.replace(/'/g, '\'\'') : '';
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -118,8 +124,8 @@ export function doubleEscapeSingleQuotes(value: string): string {
|
||||
* ' => ''
|
||||
* @param value The string to escape
|
||||
*/
|
||||
export function doubleEscapeSingleBrackets(value: string): string {
|
||||
return value.replace(/\[/g, '[[').replace(/\]/g, ']]');
|
||||
export function doubleEscapeSingleBrackets(value: string | undefined): string {
|
||||
return value ? value.replace(/\[/g, '[[').replace(/\]/g, ']]') : '';
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -176,3 +182,48 @@ export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Pr
|
||||
|
||||
return choices[result.label] || false;
|
||||
}
|
||||
|
||||
export function makeLinuxPath(filePath: string): string {
|
||||
const parts = filePath.split('\\');
|
||||
return parts.join('/');
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param currentDb Wraps the given script with database switch scripts
|
||||
* @param databaseName
|
||||
* @param script
|
||||
*/
|
||||
export function getScriptWithDBChange(currentDb: string, databaseName: string, script: string): string {
|
||||
if (!currentDb) {
|
||||
currentDb = 'master';
|
||||
}
|
||||
let escapedDbName = doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = doubleEscapeSingleBrackets(currentDb);
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
${script}
|
||||
USE [${escapedCurrentDbName}]
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns full name of model registration table
|
||||
* @param config config
|
||||
*/
|
||||
export function getRegisteredModelsThreePartsName(config: Config) {
|
||||
const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName);
|
||||
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
|
||||
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
||||
return `[${dbName}].${schema}.[${tableName}]`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns full name of model registration table
|
||||
* @param config config object
|
||||
*/
|
||||
export function getRegisteredModelsTowPartsName(config: Config) {
|
||||
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
|
||||
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
||||
return `[${schema}].[${tableName}]`;
|
||||
}
|
||||
|
||||
@@ -82,6 +82,13 @@ export class Config {
|
||||
return this._configValues.modelManagement.registeredModelsTableName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns registered models table schema name
|
||||
*/
|
||||
public get registeredModelTableSchemaName(): string {
|
||||
return this._configValues.modelManagement.registeredModelsTableSchemaName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns registered models table name
|
||||
*/
|
||||
|
||||
@@ -22,6 +22,7 @@ import { ModelManagementController } from '../views/models/modelManagementContro
|
||||
import { RegisteredModelService } from '../modelManagement/registeredModelService';
|
||||
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
|
||||
import { ModelImporter } from '../modelManagement/modelImporter';
|
||||
import { PredictService } from '../prediction/predictService';
|
||||
|
||||
/**
|
||||
* The main controller class that initializes the extension
|
||||
@@ -109,7 +110,9 @@ export default class MainController implements vscode.Disposable {
|
||||
//
|
||||
let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
|
||||
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
|
||||
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, azureModelsService, registeredModelService);
|
||||
let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config);
|
||||
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
|
||||
azureModelsService, registeredModelService, predictService);
|
||||
|
||||
this._apiWrapper.registerCommand(constants.mlManageLanguagesCommand, (async () => {
|
||||
await languageController.manageLanguages();
|
||||
@@ -120,6 +123,9 @@ export default class MainController implements vscode.Disposable {
|
||||
this._apiWrapper.registerCommand(constants.mlRegisterModelCommand, (async () => {
|
||||
await modelManagementController.registerModel();
|
||||
}));
|
||||
this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => {
|
||||
await modelManagementController.predictModel();
|
||||
}));
|
||||
this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => {
|
||||
await packageManager.installDependencies();
|
||||
}));
|
||||
@@ -135,6 +141,9 @@ export default class MainController implements vscode.Disposable {
|
||||
this._apiWrapper.registerTaskHandler(constants.mlRegisterModelCommand, async () => {
|
||||
await modelManagementController.registerModel();
|
||||
});
|
||||
this._apiWrapper.registerTaskHandler(constants.mlsPredictModelCommand, async () => {
|
||||
await modelManagementController.predictModel();
|
||||
});
|
||||
this._apiWrapper.registerTaskHandler(constants.mlOdbcDriverCommand, async () => {
|
||||
await this.serverConfigManager.openOdbcDriverDocuments();
|
||||
});
|
||||
|
||||
@@ -48,13 +48,19 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
||||
/**
|
||||
* An interface representing registered model
|
||||
*/
|
||||
export interface RegisteredModel {
|
||||
id?: number,
|
||||
artifactName?: string,
|
||||
title?: string,
|
||||
created?: string,
|
||||
version?: string
|
||||
description?: string
|
||||
export interface RegisteredModel extends RegisteredModelDetails {
|
||||
id: number;
|
||||
artifactName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* An interface representing registered model
|
||||
*/
|
||||
export interface RegisteredModelDetails {
|
||||
title: string;
|
||||
created?: string;
|
||||
version?: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -12,6 +12,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as utils from '../common/utils';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
import * as os from 'os';
|
||||
|
||||
/**
|
||||
* Service to import model to database
|
||||
@@ -39,8 +40,8 @@ export class ModelImporter {
|
||||
|
||||
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
|
||||
const parts = modelFolderPath.split('\\');
|
||||
modelFolderPath = parts.join('/');
|
||||
let home = utils.makeLinuxPath(os.homedir());
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
|
||||
@@ -51,9 +52,12 @@ export class ModelImporter {
|
||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
|
||||
let scripts: string[] = [
|
||||
'import mlflow.onnx',
|
||||
`tracking_uri = "file://${home}/mlruns"`,
|
||||
'print(tracking_uri)',
|
||||
'import onnx',
|
||||
'from mlflow.tracking.client import MlflowClient',
|
||||
`onx = onnx.load("${modelFolderPath}")`,
|
||||
`mlflow.set_tracking_uri(tracking_uri)`,
|
||||
'client = MlflowClient()',
|
||||
`exp_name = "${experimentId}"`,
|
||||
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
|
||||
|
||||
@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import { Config } from '../configurations/config';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { RegisteredModel } from './interfaces';
|
||||
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
|
||||
import { ModelImporter } from './modelImporter';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
@@ -32,7 +32,10 @@ export class RegisteredModelService {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: RegisteredModel[] = [];
|
||||
if (connection) {
|
||||
let result = await this.runRegisteredModelsListQuery(connection);
|
||||
let query = this.getConfigureQuery(connection.databaseName);
|
||||
await this._queryRunner.safeRunQuery(connection, query);
|
||||
query = this.registeredModelsQuery();
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push(this.loadModelData(row));
|
||||
@@ -57,7 +60,8 @@ export class RegisteredModelService {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
let result = await this.runUpdateModelQuery(connection, model);
|
||||
const query = this.getUpdateModelScript(connection.databaseName, model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
@@ -66,7 +70,7 @@ export class RegisteredModelService {
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let currentModels = await this.getRegisteredModels();
|
||||
@@ -93,35 +97,14 @@ export class RegisteredModelService {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
private getConfigureQuery(currentDatabaseName: string): string {
|
||||
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
|
||||
}
|
||||
|
||||
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
|
||||
private registeredModelsQuery(): string {
|
||||
return `
|
||||
${this.configureTable(databaseName, tableName)}
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM [${escapedDbName}].dbo.[${escapedTableName}]
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||
Order by artifact_id
|
||||
`;
|
||||
@@ -133,52 +116,74 @@ export class RegisteredModelService {
|
||||
* @param databaseName
|
||||
* @param tableName
|
||||
*/
|
||||
private configureTable(databaseName: string, tableName: string): string {
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
private configureTable(): string {
|
||||
let databaseName = this._config.registeredModelDatabaseName;
|
||||
let tableName = this._config.registeredModelTableName;
|
||||
let schemaName = this._config.registeredModelTableSchemaName;
|
||||
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
IF NOT EXISTS (
|
||||
SELECT [name]
|
||||
FROM sys.databases
|
||||
WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}'
|
||||
)
|
||||
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
||||
GO
|
||||
USE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
||||
IF EXISTS
|
||||
( SELECT [name]
|
||||
FROM sys.tables
|
||||
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
( SELECT [t.name], [s.name]
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
||||
)
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name')
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version')
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created')
|
||||
BEGIN
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
END
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description')
|
||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL
|
||||
END
|
||||
Else
|
||||
BEGIN
|
||||
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
|
||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[artifact_name] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NOT NULL,
|
||||
[artifact_content] [varbinary](max) NOT NULL,
|
||||
[artifact_initial_size] [bigint] NULL,
|
||||
[name] [varchar](256) NULL,
|
||||
[version] [varchar](256) NULL,
|
||||
[created] [datetime] NULL,
|
||||
[description] [varchar](256) NULL,
|
||||
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[artifact_id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||
END
|
||||
`;
|
||||
}
|
||||
|
||||
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
|
||||
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
let updateScript = `
|
||||
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||
SET
|
||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||
WHERE artifact_id = ${model.id}`;
|
||||
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
UPDATE ${escapedTableName}
|
||||
SET
|
||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||
WHERE artifact_id = ${model.id};
|
||||
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
|
||||
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
export interface PredictColumn {
|
||||
name: string;
|
||||
dataType?: string;
|
||||
displayName?: string;
|
||||
}
|
||||
|
||||
export interface DatabaseTable {
|
||||
databaseName: string | undefined;
|
||||
tableName: string | undefined;
|
||||
schema: string | undefined
|
||||
}
|
||||
|
||||
export interface PredictInputParameters extends DatabaseTable {
|
||||
inputColumns: PredictColumn[] | undefined
|
||||
}
|
||||
|
||||
export interface PredictParameters extends PredictInputParameters {
|
||||
outputColumns: PredictColumn[] | undefined
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import * as utils from '../common/utils';
|
||||
import { RegisteredModel } from '../modelManagement/interfaces';
|
||||
import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces';
|
||||
import { Config } from '../configurations/config';
|
||||
|
||||
/**
|
||||
* Service to make prediction
|
||||
*/
|
||||
export class PredictService {
|
||||
|
||||
/**
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _queryRunner: QueryRunner,
|
||||
private _config: Config) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the list of databases
|
||||
*/
|
||||
public async getDatabaseList(): Promise<string[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
return await this._apiWrapper.listDatabases(connection.connectionId);
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates prediction script given model info and predict parameters
|
||||
* @param predictParams predict parameters
|
||||
* @param registeredModel model parameters
|
||||
*/
|
||||
public async generatePredictScript(
|
||||
predictParams: PredictParameters,
|
||||
registeredModel: RegisteredModel | undefined,
|
||||
filePath: string | undefined
|
||||
): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let query = '';
|
||||
if (registeredModel && registeredModel.id) {
|
||||
query = this.getPredictScriptWithModelId(
|
||||
registeredModel.id,
|
||||
predictParams.inputColumns || [],
|
||||
predictParams.outputColumns || [],
|
||||
predictParams);
|
||||
} else if (filePath) {
|
||||
let modelBytes = await utils.readFileInHex(filePath || '');
|
||||
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
|
||||
predictParams.outputColumns || [],
|
||||
predictParams);
|
||||
}
|
||||
let document = await this._apiWrapper.openTextDocument({
|
||||
language: 'sql',
|
||||
content: query
|
||||
});
|
||||
await this._apiWrapper.showTextDocument(document.uri);
|
||||
await this._apiWrapper.connect(document.uri.toString(), connection.connectionId);
|
||||
this._apiWrapper.runQuery(document.uri.toString(), undefined, false);
|
||||
return query;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of tables given database name
|
||||
* @param databaseName database name
|
||||
*/
|
||||
public async getTableList(databaseName: string): Promise<DatabaseTable[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: DatabaseTable[] = [];
|
||||
if (connection) {
|
||||
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName));
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push({
|
||||
databaseName: databaseName,
|
||||
tableName: row[0].displayValue,
|
||||
schema: row[1].displayValue
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
/**
|
||||
*Returns list of column names of a database
|
||||
* @param databaseTable table info
|
||||
*/
|
||||
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<string[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: string[] = [];
|
||||
if (connection && databaseTable.databaseName) {
|
||||
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push(row[0].displayValue);
|
||||
});
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private getTableColumnsScript(databaseTable: DatabaseTable): string {
|
||||
return `
|
||||
SELECT COLUMN_NAME,*
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
|
||||
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
|
||||
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
|
||||
`;
|
||||
}
|
||||
|
||||
private getTablesScript(databaseName: string): string {
|
||||
return `
|
||||
SELECT TABLE_NAME,TABLE_SCHEMA
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
|
||||
`;
|
||||
}
|
||||
|
||||
private getPredictScriptWithModelId(
|
||||
modelId: number,
|
||||
columns: PredictColumn[],
|
||||
outputColumns: PredictColumn[],
|
||||
databaseNameTable: DatabaseTable): string {
|
||||
return `
|
||||
DECLARE @model VARBINARY(max) = (
|
||||
SELECT artifact_content
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_id = ${modelId}
|
||||
);
|
||||
WITH predict_input
|
||||
AS (
|
||||
SELECT TOP 1000
|
||||
${this.getColumnNames(columns, 'pi')}
|
||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||
)
|
||||
SELECT
|
||||
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
|
||||
FROM PREDICT(MODEL = @model, DATA = predict_input)
|
||||
WITH (
|
||||
${this.getColumnTypes(outputColumns)}
|
||||
) AS p
|
||||
`;
|
||||
}
|
||||
|
||||
private getPredictScriptWithModelBytes(
|
||||
modelBytes: string,
|
||||
columns: PredictColumn[],
|
||||
outputColumns: PredictColumn[],
|
||||
databaseNameTable: DatabaseTable): string {
|
||||
return `
|
||||
WITH predict_input
|
||||
AS (
|
||||
SELECT TOP 1000
|
||||
${this.getColumnNames(columns, 'pi')}
|
||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||
)
|
||||
SELECT
|
||||
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
|
||||
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
|
||||
WITH (
|
||||
${this.getColumnTypes(outputColumns)}
|
||||
) AS p
|
||||
`;
|
||||
}
|
||||
|
||||
private getColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getColumnTypes(columns: PredictColumn[]) {
|
||||
return columns.map(c => {
|
||||
return `${c.name} ${c.dataType}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { RegisterModelWizard } from '../../../views/models/registerModelWizard';
|
||||
import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard';
|
||||
|
||||
describe('Register Model Wizard', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
@@ -74,7 +74,8 @@ describe('Register Model Wizard', () => {
|
||||
let localModels: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model'
|
||||
artifactName: 'model',
|
||||
title: 'model'
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { RegisteredModelsDialog } from '../../../views/models/registeredModelsDialog';
|
||||
import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog';
|
||||
import { ListModelsEventName } from '../../../views/models/modelViewBase';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
@@ -30,7 +30,8 @@ describe('Registered Models Dialog', () => {
|
||||
let models: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model'
|
||||
artifactName: 'model',
|
||||
title: ''
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
|
||||
@@ -246,6 +246,7 @@ export function createViewContext(): ViewTestContext {
|
||||
modelView: undefined!,
|
||||
valid: true
|
||||
};
|
||||
apiWrapper.setup(x => x.createButton(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => dialogButton);
|
||||
apiWrapper.setup(x => x.createTab(TypeMoq.It.isAny())).returns(() => tab);
|
||||
apiWrapper.setup(x => x.createWizard(TypeMoq.It.isAny())).returns(() => wizard);
|
||||
apiWrapper.setup(x => x.createWizardPage(TypeMoq.It.isAny())).returns(() => wizardPage);
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { ViewBase, LocalFileEventName, LocalFolderEventName } from './viewBase';
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
import { ViewBase, LocalPathsEventName } from './viewBase';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
|
||||
/**
|
||||
@@ -36,11 +38,8 @@ export abstract class ControllerBase {
|
||||
* @param view view
|
||||
*/
|
||||
public registerEvents(view: ViewBase): void {
|
||||
view.on(LocalFileEventName, async () => {
|
||||
await this.executeAction(view, LocalFileEventName, this.getLocalFilePath, this._apiWrapper);
|
||||
});
|
||||
view.on(LocalFolderEventName, async () => {
|
||||
await this.executeAction(view, LocalFolderEventName, this.getLocalFolderPath, this._apiWrapper);
|
||||
view.on(LocalPathsEventName, async (args) => {
|
||||
await this.executeAction(view, LocalPathsEventName, this.getLocalPaths, this._apiWrapper, args);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -48,25 +47,8 @@ export abstract class ControllerBase {
|
||||
* Returns local file path picked by the user
|
||||
* @param apiWrapper apiWrapper
|
||||
*/
|
||||
public async getLocalFilePath(apiWrapper: ApiWrapper): Promise<string> {
|
||||
let result = await apiWrapper.showOpenDialog({
|
||||
canSelectFiles: true,
|
||||
canSelectFolders: false,
|
||||
canSelectMany: false
|
||||
});
|
||||
return result && result.length > 0 ? result[0].fsPath : '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns local folder path picked by the user
|
||||
* @param apiWrapper apiWrapper
|
||||
*/
|
||||
public async getLocalFolderPath(apiWrapper: ApiWrapper): Promise<string> {
|
||||
let result = await apiWrapper.showOpenDialog({
|
||||
canSelectFiles: false,
|
||||
canSelectFolders: true,
|
||||
canSelectMany: false
|
||||
});
|
||||
return result && result.length > 0 ? result[0].fsPath : '';
|
||||
public async getLocalPaths(apiWrapper: ApiWrapper, options: vscode.OpenDialogOptions): Promise<string[]> {
|
||||
let result = await apiWrapper.showOpenDialog(options);
|
||||
return result ? result?.map(x => x.fsPath) : [];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ export interface IPageView {
|
||||
component: azdata.Component | undefined;
|
||||
onEnter?: () => Promise<void>;
|
||||
onLeave?: () => Promise<void>;
|
||||
validate?: () => Promise<boolean>;
|
||||
refresh: () => Promise<void>;
|
||||
viewPanel: azdata.window.ModelViewPanel | undefined;
|
||||
title: string;
|
||||
@@ -32,3 +33,4 @@ export interface AzureModelResource extends AzureWorkspaceResource {
|
||||
model?: WorkspaceModel;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
@@ -43,9 +45,17 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen
|
||||
}
|
||||
}).component();
|
||||
this._localBrowse.onDidClick(async () => {
|
||||
const filePath = await this.getLocalFilePath();
|
||||
|
||||
let options: vscode.OpenDialogOptions = {
|
||||
canSelectFiles: true,
|
||||
canSelectFolders: false,
|
||||
canSelectMany: false,
|
||||
filters: { 'ONNX File': ['onnx'] }
|
||||
};
|
||||
|
||||
const filePaths = await this.getLocalPaths(options);
|
||||
if (this._localPath) {
|
||||
this._localPath.value = filePath;
|
||||
this._localPath.value = filePaths && filePaths.length > 0 ? filePaths[0] : '';
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -8,12 +8,12 @@ import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick local models file
|
||||
*/
|
||||
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModel> {
|
||||
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModelDetails> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _nameComponent: azdata.InputBoxComponent | undefined;
|
||||
@@ -81,9 +81,9 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): RegisteredModel {
|
||||
public get data(): RegisteredModelDetails {
|
||||
return {
|
||||
title: this._nameComponent?.value,
|
||||
title: this._nameComponent?.value || '',
|
||||
description: this._descriptionComponent?.value
|
||||
};
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { ModelDetailsComponent } from './modelDetailsComponent';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick model details
|
||||
*/
|
||||
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModel> {
|
||||
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModelDetails> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
@@ -43,7 +43,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): RegisteredModel | undefined {
|
||||
public get data(): RegisteredModelDetails | undefined {
|
||||
return this.modelDetails?.data;
|
||||
}
|
||||
|
||||
@@ -66,4 +66,13 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
||||
public get title(): string {
|
||||
return constants.modelDetailsPageTitle;
|
||||
}
|
||||
|
||||
public validate(): Promise<boolean> {
|
||||
if (this.data && this.data.title) {
|
||||
return Promise.resolve(true);
|
||||
} else {
|
||||
this.showErrorMessage(constants.modelNameRequiredError);
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,14 +9,23 @@ import { azureResource } from '../../typings/azure-resource';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
|
||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
|
||||
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
|
||||
import { RegisteredModelsDialog } from './registeredModelsDialog';
|
||||
import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName } from './modelViewBase';
|
||||
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
|
||||
import {
|
||||
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
|
||||
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
|
||||
} from './modelViewBase';
|
||||
import { ControllerBase } from '../controllerBase';
|
||||
import { RegisterModelWizard } from './registerModelWizard';
|
||||
import { RegisterModelWizard } from './registerModels/registerModelWizard';
|
||||
import * as fs from 'fs';
|
||||
import * as constants from '../../common/constants';
|
||||
import { PredictWizard } from './prediction/predictWizard';
|
||||
import { AzureModelResource } from '../interfaces';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
|
||||
/**
|
||||
* Model management UI controller
|
||||
@@ -30,7 +39,8 @@ export class ModelManagementController extends ControllerBase {
|
||||
apiWrapper: ApiWrapper,
|
||||
private _root: string,
|
||||
private _amlService: AzureModelRegistryService,
|
||||
private _registeredModelService: RegisteredModelService) {
|
||||
private _registeredModelService: RegisteredModelService,
|
||||
private _predictService: PredictService) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
|
||||
@@ -56,6 +66,23 @@ export class ModelManagementController extends ControllerBase {
|
||||
return view;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the wizard for prediction
|
||||
*/
|
||||
public async predictModel(): Promise<ModelViewBase> {
|
||||
|
||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||
|
||||
this.registerEvents(view);
|
||||
|
||||
// Open view
|
||||
//
|
||||
view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Register events in the main view
|
||||
* @param view main view
|
||||
@@ -102,6 +129,28 @@ export class ModelManagementController extends ControllerBase {
|
||||
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
|
||||
});
|
||||
view.on(DownloadAzureModelEventName, async (arg) => {
|
||||
let registerArgs = <AzureModelResource>arg;
|
||||
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, async () => {
|
||||
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
|
||||
});
|
||||
view.on(ListTableNamesEventName, async (arg) => {
|
||||
let dbName = <string>arg;
|
||||
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
|
||||
});
|
||||
view.on(ListColumnNamesEventName, async (arg) => {
|
||||
let tableColumnsArgs = <DatabaseTable>arg;
|
||||
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
|
||||
tableColumnsArgs);
|
||||
});
|
||||
view.on(PredictModelEventName, async (arg) => {
|
||||
let predictArgs = <PredictModelEventArgs>arg;
|
||||
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
|
||||
predictArgs, predictArgs.model, predictArgs.filePath);
|
||||
});
|
||||
view.on(SourceModelSelectedEventName, () => {
|
||||
view.refresh();
|
||||
});
|
||||
@@ -158,7 +207,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
||||
}
|
||||
|
||||
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise<void> {
|
||||
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
if (filePath) {
|
||||
await service.registerLocalModel(filePath, details);
|
||||
} else {
|
||||
@@ -175,7 +224,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
workspace: Workspace | undefined,
|
||||
model: WorkspaceModel | undefined,
|
||||
details: RegisteredModel | undefined): Promise<void> {
|
||||
details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
|
||||
throw Error(constants.invalidAzureResourceError);
|
||||
}
|
||||
@@ -188,4 +237,47 @@ export class ModelManagementController extends ControllerBase {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
|
||||
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
||||
return await predictService.getDatabaseList();
|
||||
}
|
||||
|
||||
public async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
|
||||
return await predictService.getTableList(databaseName);
|
||||
}
|
||||
|
||||
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
|
||||
return await predictService.getTableColumnsList(databaseTable);
|
||||
}
|
||||
|
||||
private async generatePredictScript(
|
||||
predictService: PredictService,
|
||||
predictParams: PredictParameters,
|
||||
registeredModel: RegisteredModel | undefined,
|
||||
filePath: string | undefined
|
||||
): Promise<string> {
|
||||
if (!predictParams) {
|
||||
throw Error(constants.invalidModelToPredictError);
|
||||
}
|
||||
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
|
||||
return result;
|
||||
}
|
||||
|
||||
private async downloadAzureModel(
|
||||
azureService: AzureModelRegistryService,
|
||||
account: azdata.Account | undefined,
|
||||
subscription: azureResource.AzureResourceSubscription | undefined,
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
workspace: Workspace | undefined,
|
||||
model: WorkspaceModel | undefined): Promise<string> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace || !model) {
|
||||
throw Error(constants.invalidAzureResourceError);
|
||||
}
|
||||
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
|
||||
if (filePath) {
|
||||
return filePath;
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
|
||||
import { LocalModelsComponent } from './localModelsComponent';
|
||||
import { AzureModelsComponent } from './azureModelsComponent';
|
||||
import { CurrentModelsTable } from './registerModels/currentModelsTable';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
@@ -22,8 +23,9 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
||||
public modelResources: ModelSourcesComponent | undefined;
|
||||
public localModelsComponent: LocalModelsComponent | undefined;
|
||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||
public registeredModelsComponent: CurrentModelsTable | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
@@ -34,13 +36,15 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
|
||||
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
|
||||
this.modelResources.registerComponent(modelBuilder);
|
||||
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
|
||||
this.localModelsComponent.registerComponent(modelBuilder);
|
||||
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
|
||||
this.azureModelsComponent.registerComponent(modelBuilder);
|
||||
this.modelResources.addComponents(this._formBuilder);
|
||||
this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this);
|
||||
this.registeredModelsComponent.registerComponent(modelBuilder);
|
||||
this.refresh();
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
@@ -66,19 +70,29 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._formBuilder) {
|
||||
if (this.modelResources && this.modelResources.data === ModelSourceType.Local) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||
this.localModelsComponent.addComponents(this._formBuilder);
|
||||
await this.localModelsComponent.refresh();
|
||||
}
|
||||
|
||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||
this.azureModelsComponent.addComponents(this._formBuilder);
|
||||
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||
await this.azureModelsComponent.refresh();
|
||||
}
|
||||
|
||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||
this.registeredModelsComponent.addComponents(this._formBuilder);
|
||||
await this.registeredModelsComponent.refresh();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -89,4 +103,21 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
||||
public get title(): string {
|
||||
return constants.modelSourcePageTitle;
|
||||
}
|
||||
|
||||
public validate(): Promise<boolean> {
|
||||
let validated = false;
|
||||
if (this.modelResources && this.modelResources.data === ModelSourceType.Local && this.localModelsComponent) {
|
||||
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
|
||||
|
||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure && this.azureModelsComponent) {
|
||||
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.model !== undefined;
|
||||
|
||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
|
||||
validated = this.registeredModelsComponent.data !== undefined;
|
||||
}
|
||||
if (!validated) {
|
||||
this.showErrorMessage(constants.invalidModelToSelectError);
|
||||
}
|
||||
return Promise.resolve(validated);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,8 @@ import { IDataComponent } from '../interfaces';
|
||||
|
||||
export enum ModelSourceType {
|
||||
Local,
|
||||
Azure
|
||||
Azure,
|
||||
RegisteredModels
|
||||
}
|
||||
/**
|
||||
* View to pick model source
|
||||
@@ -22,9 +23,10 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
private _flexContainer: azdata.FlexContainer | undefined;
|
||||
private _amlModel: azdata.RadioButtonComponent | undefined;
|
||||
private _localModel: azdata.RadioButtonComponent | undefined;
|
||||
private _isLocalModel: boolean = true;
|
||||
private _registeredModels: azdata.RadioButtonComponent | undefined;
|
||||
private _sourceType: ModelSourceType = ModelSourceType.Local;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
@@ -38,7 +40,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
value: 'local',
|
||||
name: 'modelLocation',
|
||||
label: constants.localModelSource,
|
||||
checked: true
|
||||
checked: this._options[0] === ModelSourceType.Local
|
||||
}).component();
|
||||
|
||||
|
||||
@@ -47,26 +49,58 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
value: 'aml',
|
||||
name: 'modelLocation',
|
||||
label: constants.azureModelSource,
|
||||
checked: this._options[0] === ModelSourceType.Azure
|
||||
}).component();
|
||||
|
||||
this._registeredModels = modelBuilder.radioButton()
|
||||
.withProperties({
|
||||
value: 'registered',
|
||||
name: 'modelLocation',
|
||||
label: constants.registeredModelsSource,
|
||||
checked: this._options[0] === ModelSourceType.RegisteredModels
|
||||
}).component();
|
||||
|
||||
this._localModel.onDidClick(() => {
|
||||
this._isLocalModel = true;
|
||||
this._sourceType = ModelSourceType.Local;
|
||||
this.sendRequest(SourceModelSelectedEventName);
|
||||
|
||||
});
|
||||
this._amlModel.onDidClick(() => {
|
||||
this._isLocalModel = false;
|
||||
this._sourceType = ModelSourceType.Azure;
|
||||
this.sendRequest(SourceModelSelectedEventName);
|
||||
});
|
||||
this._registeredModels.onDidClick(() => {
|
||||
this._sourceType = ModelSourceType.RegisteredModels;
|
||||
this.sendRequest(SourceModelSelectedEventName);
|
||||
});
|
||||
let components: azdata.RadioButtonComponent[] = [];
|
||||
|
||||
this._options.forEach(option => {
|
||||
switch (option) {
|
||||
case ModelSourceType.Local:
|
||||
if (this._localModel) {
|
||||
components.push(this._localModel);
|
||||
}
|
||||
break;
|
||||
case ModelSourceType.Azure:
|
||||
if (this._amlModel) {
|
||||
components.push(this._amlModel);
|
||||
}
|
||||
break;
|
||||
case ModelSourceType.RegisteredModels:
|
||||
if (this._registeredModels) {
|
||||
components.push(this._registeredModels);
|
||||
}
|
||||
break;
|
||||
}
|
||||
});
|
||||
this._sourceType = this._options[0];
|
||||
|
||||
this._flexContainer = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'column',
|
||||
justifyContent: 'space-between'
|
||||
}).withItems([
|
||||
this._localModel, this._amlModel]
|
||||
).component();
|
||||
}).withItems(components).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
@@ -92,7 +126,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ModelSourceType {
|
||||
return this._isLocalModel ? ModelSourceType.Local : ModelSourceType.Azure;
|
||||
return this._sourceType;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -8,7 +8,8 @@ import * as azdata from 'azdata';
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { ViewBase } from '../viewBase';
|
||||
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
|
||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
||||
|
||||
@@ -16,13 +17,18 @@ export interface AzureResourceEventArgs extends AzureWorkspaceResource {
|
||||
}
|
||||
|
||||
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
|
||||
details?: RegisteredModel
|
||||
details?: RegisteredModelDetails
|
||||
}
|
||||
|
||||
export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs {
|
||||
model?: WorkspaceModel;
|
||||
}
|
||||
|
||||
export interface PredictModelEventArgs extends PredictParameters {
|
||||
model?: RegisteredModel;
|
||||
filePath?: string;
|
||||
}
|
||||
|
||||
export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
|
||||
filePath?: string;
|
||||
}
|
||||
@@ -32,11 +38,16 @@ export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
|
||||
export const ListModelsEventName = 'listModels';
|
||||
export const ListAzureModelsEventName = 'listAzureModels';
|
||||
export const ListAccountsEventName = 'listAccounts';
|
||||
export const ListDatabaseNamesEventName = 'listDatabaseNames';
|
||||
export const ListTableNamesEventName = 'listTableNames';
|
||||
export const ListColumnNamesEventName = 'listColumnNames';
|
||||
export const ListSubscriptionsEventName = 'listSubscriptions';
|
||||
export const ListGroupsEventName = 'listGroups';
|
||||
export const ListWorkspacesEventName = 'listWorkspaces';
|
||||
export const RegisterLocalModelEventName = 'registerLocalModel';
|
||||
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
|
||||
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
|
||||
export const PredictModelEventName = 'predictModel';
|
||||
export const RegisterModelEventName = 'registerModel';
|
||||
export const SourceModelSelectedEventName = 'sourceModelSelected';
|
||||
|
||||
@@ -59,7 +70,12 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
RegisterLocalModelEventName,
|
||||
RegisterAzureModelEventName,
|
||||
RegisterModelEventName,
|
||||
SourceModelSelectedEventName]);
|
||||
SourceModelSelectedEventName,
|
||||
ListDatabaseNamesEventName,
|
||||
ListTableNamesEventName,
|
||||
ListColumnNamesEventName,
|
||||
PredictModelEventName,
|
||||
DownloadAzureModelEventName]);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -91,6 +107,27 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
return await this.sendDataRequest(ListAccountsEventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists database names
|
||||
*/
|
||||
public async listDatabaseNames(): Promise<string[]> {
|
||||
return await this.sendDataRequest(ListDatabaseNamesEventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists table names
|
||||
*/
|
||||
public async listTableNames(dbName: string): Promise<DatabaseTable[]> {
|
||||
return await this.sendDataRequest(ListTableNamesEventName, dbName);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists column names
|
||||
*/
|
||||
public async listColumnNames(table: DatabaseTable): Promise<string[]> {
|
||||
return await this.sendDataRequest(ListColumnNamesEventName, table);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists azure subscriptions
|
||||
* @param account azure account
|
||||
@@ -106,7 +143,7 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
* registers local model
|
||||
* @param localFilePath local file path
|
||||
*/
|
||||
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModel | undefined): Promise<void> {
|
||||
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
const args: RegisterLocalModelEventArgs = {
|
||||
filePath: localFilePath,
|
||||
details: details
|
||||
@@ -114,17 +151,38 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
return await this.sendDataRequest(RegisterLocalModelEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* download azure model
|
||||
* @param args azure resource
|
||||
*/
|
||||
public async downloadAzureModel(resource: AzureModelResource | undefined): Promise<string> {
|
||||
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
|
||||
}
|
||||
|
||||
/**
|
||||
* registers azure model
|
||||
* @param args azure resource
|
||||
*/
|
||||
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModel | undefined): Promise<void> {
|
||||
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
const args: RegisterAzureModelEventArgs = Object.assign({}, resource, {
|
||||
details: details
|
||||
});
|
||||
return await this.sendDataRequest(RegisterAzureModelEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* registers azure model
|
||||
* @param args azure resource
|
||||
*/
|
||||
public async generatePredictScript(model: RegisteredModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> {
|
||||
const args: PredictModelEventArgs = Object.assign({}, params, {
|
||||
model: model,
|
||||
filePath: filePath,
|
||||
loadFromRegisteredModel: !filePath
|
||||
});
|
||||
return await this.sendDataRequest(PredictModelEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* list resource groups
|
||||
* @param account azure account
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { ColumnsTable } from './columnsTable';
|
||||
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
export class ColumnsFilterComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _databases: azdata.DropDownComponent | undefined;
|
||||
private _tables: azdata.DropDownComponent | undefined;
|
||||
private _columns: ColumnsTable | undefined;
|
||||
private _dbNames: string[] = [];
|
||||
private _tableNames: DatabaseTable[] = [];
|
||||
|
||||
/**
|
||||
* Creates a new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._databases = modelBuilder.dropDown().withProperties({
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
this._tables = modelBuilder.dropDown().withProperties({
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
|
||||
|
||||
this._databases.onValueChanged(async () => {
|
||||
await this.onDatabaseSelected();
|
||||
});
|
||||
|
||||
this._tables.onValueChanged(async () => {
|
||||
await this.onTableSelected();
|
||||
});
|
||||
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._databases
|
||||
}, {
|
||||
title: constants.azureSubscription,
|
||||
component: this._tables
|
||||
}, {
|
||||
title: constants.azureGroup,
|
||||
component: this._columns.component
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._databases && this._tables && this._columns) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.columnDatabase,
|
||||
component: this._databases
|
||||
}, {
|
||||
title: constants.columnTable,
|
||||
component: this._tables
|
||||
}, {
|
||||
title: constants.inputColumns,
|
||||
component: this._columns.component
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._databases && this._tables && this._columns) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureAccount,
|
||||
component: this._databases
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureSubscription,
|
||||
component: this._tables
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureGroup,
|
||||
component: this._columns.component
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the created component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictInputParameters | undefined {
|
||||
return Object.assign({}, this.databaseTable, {
|
||||
inputColumns: this.columnNames
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* loads data in the components
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
this._dbNames = await this.listDatabaseNames();
|
||||
if (this._databases && this._dbNames && this._dbNames.length > 0) {
|
||||
this._databases.values = this._dbNames;
|
||||
this._databases.value = this._dbNames[0];
|
||||
}
|
||||
await this.onDatabaseSelected();
|
||||
}
|
||||
|
||||
/**
|
||||
* refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
|
||||
private async onDatabaseSelected(): Promise<void> {
|
||||
this._tableNames = await this.listTableNames(this.databaseName || '');
|
||||
if (this._tables && this._tableNames && this._tableNames.length > 0) {
|
||||
this._tables.values = this._tableNames.map(t => this.getTableFullName(t));
|
||||
this._tables.value = this.getTableFullName(this._tableNames[0]);
|
||||
}
|
||||
await this.onTableSelected();
|
||||
}
|
||||
|
||||
private getTableFullName(table: DatabaseTable): string {
|
||||
return `${table.schema}.${table.tableName}`;
|
||||
}
|
||||
|
||||
private async onTableSelected(): Promise<void> {
|
||||
this._columns?.loadData(this.databaseTable);
|
||||
}
|
||||
|
||||
private get databaseName(): string | undefined {
|
||||
return <string>this._databases?.value;
|
||||
}
|
||||
|
||||
private get databaseTable(): DatabaseTable {
|
||||
let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._tables?.value);
|
||||
return {
|
||||
databaseName: this.databaseName,
|
||||
tableName: selectedItem?.tableName,
|
||||
schema: selectedItem?.schema
|
||||
};
|
||||
}
|
||||
|
||||
private get columnNames(): PredictColumn[] | undefined {
|
||||
return this._columns?.data;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../../interfaces';
|
||||
import { ColumnsFilterComponent } from './columnsFilterComponent';
|
||||
import { OutputColumnsComponent } from './outputColumnsComponent';
|
||||
import { PredictParameters } from '../../../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ColumnsSelectionPage extends ModelViewBase implements IPageView, IDataComponent<PredictParameters> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public columnsFilterComponent: ColumnsFilterComponent | undefined;
|
||||
public outputColumnsComponent: OutputColumnsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this);
|
||||
this.columnsFilterComponent.registerComponent(modelBuilder);
|
||||
this.columnsFilterComponent.addComponents(this._formBuilder);
|
||||
this.refresh();
|
||||
|
||||
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
|
||||
this.outputColumnsComponent.registerComponent(modelBuilder);
|
||||
this.outputColumnsComponent.addComponents(this._formBuilder);
|
||||
this.refresh();
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictParameters | undefined {
|
||||
return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ?
|
||||
Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
|
||||
undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._formBuilder) {
|
||||
if (this.columnsFilterComponent) {
|
||||
await this.columnsFilterComponent.refresh();
|
||||
}
|
||||
if (this.outputColumnsComponent) {
|
||||
await this.outputColumnsComponent.refresh();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.columnSelectionPageTitle;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* View to render azure models in a table
|
||||
*/
|
||||
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent;
|
||||
private _selectedColumns: PredictColumn[] = [];
|
||||
private _columns: string[] | undefined;
|
||||
|
||||
/**
|
||||
* Creates a view to render azure models in a table
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._table = this.registerComponent(this._modelBuilder);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: [
|
||||
{ // Name
|
||||
displayName: constants.columnDatabase,
|
||||
ariaLabel: constants.columnName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 120,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: constants.inputName,
|
||||
ariaLabel: constants.inputName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
],
|
||||
data: [],
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
return this._table;
|
||||
}
|
||||
|
||||
public get component(): azdata.DeclarativeTableComponent {
|
||||
return this._table;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load data in the component
|
||||
* @param workspaceResource Azure workspace
|
||||
*/
|
||||
public async loadData(table: DatabaseTable): Promise<void> {
|
||||
this._selectedColumns = [];
|
||||
if (this._table) {
|
||||
this._columns = await this.listColumnNames(table);
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (this._columns) {
|
||||
tableData = tableData.concat(this._columns.map(model => this.createTableRow(model)));
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
}
|
||||
|
||||
private createTableRow(column: string): any[] {
|
||||
if (this._modelBuilder) {
|
||||
let selectRowButton = this._modelBuilder.checkBox().withProperties({
|
||||
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: true
|
||||
}).component();
|
||||
let nameInputBox = this._modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: 150
|
||||
}).component();
|
||||
this._selectedColumns.push({ name: column });
|
||||
selectRowButton.onChanged(() => {
|
||||
if (selectRowButton.checked) {
|
||||
if (!this._selectedColumns.find(x => x.name === column)) {
|
||||
this._selectedColumns.push({ name: column });
|
||||
}
|
||||
} else {
|
||||
if (this._selectedColumns.find(x => x.name === column)) {
|
||||
this._selectedColumns = this._selectedColumns.filter(x => x.name !== column);
|
||||
}
|
||||
}
|
||||
});
|
||||
nameInputBox.onTextChanged(() => {
|
||||
let selectedRow = this._selectedColumns.find(x => x.name === column);
|
||||
if (selectedRow) {
|
||||
selectedRow.displayName = nameInputBox.value;
|
||||
}
|
||||
});
|
||||
return [column, nameInputBox, selectRowButton];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictColumn[] | undefined {
|
||||
return this._selectedColumns;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,124 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { PredictColumn } from '../../../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
const componentWidth = 60;
|
||||
export class OutputColumnsComponent extends ModelViewBase implements IDataComponent<PredictColumn[]> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _flex: azdata.FlexContainer | undefined;
|
||||
private _columnName: azdata.InputBoxComponent | undefined;
|
||||
private _columnTypes: azdata.DropDownComponent | undefined;
|
||||
private _dataTypes: string[] = [
|
||||
'int',
|
||||
'nvarchar(MAX)',
|
||||
'varchar(MAX)',
|
||||
'float',
|
||||
'double',
|
||||
'bit'
|
||||
];
|
||||
|
||||
/**
|
||||
* Creates a new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._columnName = modelBuilder.inputBox().withProperties({
|
||||
width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength
|
||||
}).component();
|
||||
this._columnTypes = modelBuilder.dropDown().withProperties({
|
||||
width: componentWidth
|
||||
}).component();
|
||||
|
||||
let flex = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
width: this._columnName.width
|
||||
}).withItems([
|
||||
this._columnName]
|
||||
).component();
|
||||
this._flex = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between',
|
||||
width: this.componentMaxLength
|
||||
}).withItems([
|
||||
flex, this._columnTypes]
|
||||
).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._flex
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.outputColumns,
|
||||
component: this._flex
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.outputColumns,
|
||||
component: this._flex
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the created component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* loads data in the components
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
if (this._columnTypes) {
|
||||
this._columnTypes.values = this._dataTypes;
|
||||
this._columnTypes.value = this._dataTypes[0];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictColumn[] | undefined {
|
||||
return this._columnName && this._columnTypes ? [{
|
||||
name: this._columnName.value || '',
|
||||
dataType: <string>this._columnTypes.value || ''
|
||||
}] : undefined;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
|
||||
import { LocalModelsComponent } from '../localModelsComponent';
|
||||
import { AzureModelsComponent } from '../azureModelsComponent';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { WizardView } from '../../wizardView';
|
||||
import { ModelSourcePage } from '../modelSourcePage';
|
||||
import { ColumnsSelectionPage } from './columnsSelectionPage';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* Wizard to register a model
|
||||
*/
|
||||
export class PredictWizard extends ModelViewBase {
|
||||
|
||||
public modelSourcePage: ModelSourcePage | undefined;
|
||||
//public modelDetailsPage: ModelDetailsPage | undefined;
|
||||
public columnsSelectionPage: ColumnsSelectionPage | undefined;
|
||||
public wizardView: WizardView | undefined;
|
||||
private _parentView: ModelViewBase | undefined;
|
||||
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
root: string,
|
||||
parent?: ModelViewBase) {
|
||||
super(apiWrapper, root);
|
||||
this._parentView = parent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens a dialog to manage packages used by notebooks.
|
||||
*/
|
||||
public open(): void {
|
||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
|
||||
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
|
||||
this.wizardView = new WizardView(this._apiWrapper);
|
||||
|
||||
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
|
||||
[this.modelSourcePage,
|
||||
this.columnsSelectionPage]);
|
||||
|
||||
this.mainViewPanel = wizard;
|
||||
wizard.doneButton.label = constants.predictModel;
|
||||
wizard.generateScriptButton.hidden = true;
|
||||
wizard.displayPageTitles = true;
|
||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||
if (validated && pageInfo.newPage === undefined) {
|
||||
wizard.cancelButton.enabled = false;
|
||||
wizard.backButton.enabled = false;
|
||||
await this.predict();
|
||||
wizard.cancelButton.enabled = true;
|
||||
wizard.backButton.enabled = true;
|
||||
if (this._parentView) {
|
||||
this._parentView?.refresh();
|
||||
}
|
||||
return true;
|
||||
|
||||
}
|
||||
return validated;
|
||||
});
|
||||
|
||||
wizard.open();
|
||||
}
|
||||
|
||||
public get modelResources(): ModelSourcesComponent | undefined {
|
||||
return this.modelSourcePage?.modelResources;
|
||||
}
|
||||
|
||||
public get localModelsComponent(): LocalModelsComponent | undefined {
|
||||
return this.modelSourcePage?.localModelsComponent;
|
||||
}
|
||||
|
||||
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
||||
return this.modelSourcePage?.azureModelsComponent;
|
||||
}
|
||||
|
||||
private async predict(): Promise<boolean> {
|
||||
try {
|
||||
let modelFilePath: string = '';
|
||||
let registeredModel: RegisteredModel | undefined = undefined;
|
||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||
modelFilePath = this.localModelsComponent.data;
|
||||
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
||||
modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data);
|
||||
} else {
|
||||
registeredModel = this.modelSourcePage?.registeredModelsComponent?.data;
|
||||
}
|
||||
|
||||
await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data);
|
||||
return true;
|
||||
} catch (error) {
|
||||
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh the pages
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.wizardView?.refresh();
|
||||
}
|
||||
}
|
||||
@@ -5,11 +5,11 @@
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import * as constants from '../../common/constants';
|
||||
import { ModelViewBase, RegisterModelEventName } from './modelViewBase';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { CurrentModelsTable } from './currentModelsTable';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { IPageView } from '../interfaces';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { IPageView } from '../../interfaces';
|
||||
|
||||
/**
|
||||
* View to render current registered models
|
||||
@@ -33,28 +33,21 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
||||
* @param modelBuilder register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._dataTable = new CurrentModelsTable(this._apiWrapper, modelBuilder, this);
|
||||
this._dataTable = new CurrentModelsTable(this._apiWrapper, this);
|
||||
this._dataTable.registerComponent(modelBuilder);
|
||||
this._tableComponent = this._dataTable.component;
|
||||
|
||||
let registerButton = modelBuilder.button().withProperties({
|
||||
label: constants.registerModelTitle,
|
||||
width: this.buttonMaxLength
|
||||
}).component();
|
||||
registerButton.onDidClick(async () => {
|
||||
await this.sendDataRequest(RegisterModelEventName);
|
||||
});
|
||||
let formModelBuilder = modelBuilder.formContainer();
|
||||
|
||||
let formModel = modelBuilder.formContainer()
|
||||
.withFormItems([{
|
||||
title: '',
|
||||
component: registerButton
|
||||
}, {
|
||||
if (this._tableComponent) {
|
||||
formModelBuilder.addFormItem({
|
||||
component: this._tableComponent,
|
||||
title: ''
|
||||
}]).component();
|
||||
});
|
||||
}
|
||||
|
||||
this._loader = modelBuilder.loadingComponent()
|
||||
.withItem(formModel)
|
||||
.withItem(formModelBuilder.component())
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
@@ -4,24 +4,26 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
|
||||
/**
|
||||
* View to render registered models table
|
||||
*/
|
||||
export class CurrentModelsTable extends ModelViewBase {
|
||||
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<RegisteredModel> {
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent;
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
private _modelBuilder: azdata.ModelBuilder | undefined;
|
||||
private _selectedModel: any;
|
||||
|
||||
/**
|
||||
* Creates new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._table = this.registerComponent(this._modelBuilder);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -29,6 +31,7 @@ export class CurrentModelsTable extends ModelViewBase {
|
||||
* @param modelBuilder register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
||||
this._modelBuilder = modelBuilder;
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
@@ -92,10 +95,23 @@ export class CurrentModelsTable extends ModelViewBase {
|
||||
return this._table;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.component) {
|
||||
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component });
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.component) {
|
||||
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.DeclarativeTableComponent {
|
||||
public get component(): azdata.DeclarativeTableComponent | undefined {
|
||||
return this._table;
|
||||
}
|
||||
|
||||
@@ -103,38 +119,45 @@ export class CurrentModelsTable extends ModelViewBase {
|
||||
* Loads the data in the component
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
let models: RegisteredModel[] | undefined;
|
||||
if (this._table) {
|
||||
let models: RegisteredModel[] | undefined;
|
||||
|
||||
models = await this.listModels();
|
||||
let tableData: any[][] = [];
|
||||
models = await this.listModels();
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (models) {
|
||||
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
|
||||
if (models) {
|
||||
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
|
||||
private createTableRow(model: RegisteredModel): any[] {
|
||||
if (this._modelBuilder) {
|
||||
let editLanguageButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
title: constants.deleteTitle,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/edit.svg')
|
||||
},
|
||||
let selectModelButton = this._modelBuilder.radioButton().withProperties({
|
||||
name: 'amlModel',
|
||||
value: model.id,
|
||||
width: 15,
|
||||
height: 15
|
||||
height: 15,
|
||||
checked: false
|
||||
}).component();
|
||||
editLanguageButton.onDidClick(() => {
|
||||
selectModelButton.onDidClick(() => {
|
||||
this._selectedModel = model;
|
||||
});
|
||||
return [model.artifactName, model.title, model.created, editLanguageButton];
|
||||
return [model.artifactName, model.title, model.created, selectModelButton];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): RegisteredModel | undefined {
|
||||
return this._selectedModel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
@@ -4,15 +4,15 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
|
||||
import { LocalModelsComponent } from './localModelsComponent';
|
||||
import { AzureModelsComponent } from './azureModelsComponent';
|
||||
import * as constants from '../../common/constants';
|
||||
import { WizardView } from '../wizardView';
|
||||
import { ModelSourcePage } from './modelSourcePage';
|
||||
import { ModelDetailsPage } from './modelDetailsPage';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
|
||||
import { LocalModelsComponent } from '../localModelsComponent';
|
||||
import { AzureModelsComponent } from '../azureModelsComponent';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { WizardView } from '../../wizardView';
|
||||
import { ModelSourcePage } from '../modelSourcePage';
|
||||
import { ModelDetailsPage } from '../modelDetailsPage';
|
||||
|
||||
/**
|
||||
* Wizard to register a model
|
||||
@@ -47,19 +47,20 @@ export class RegisterModelWizard extends ModelViewBase {
|
||||
wizard.generateScriptButton.hidden = true;
|
||||
wizard.displayPageTitles = true;
|
||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||
if (pageInfo.newPage === undefined) {
|
||||
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||
if (validated && pageInfo.newPage === undefined) {
|
||||
wizard.cancelButton.enabled = false;
|
||||
wizard.backButton.enabled = false;
|
||||
await this.registerModel();
|
||||
let result = await this.registerModel();
|
||||
wizard.cancelButton.enabled = true;
|
||||
wizard.backButton.enabled = true;
|
||||
if (this._parentView) {
|
||||
this._parentView?.refresh();
|
||||
await this._parentView?.refresh();
|
||||
}
|
||||
return true;
|
||||
return result;
|
||||
|
||||
}
|
||||
return true;
|
||||
return validated;
|
||||
});
|
||||
|
||||
wizard.open();
|
||||
@@ -5,10 +5,10 @@
|
||||
|
||||
import { CurrentModelsPage } from './currentModelsPage';
|
||||
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { DialogView } from '../dialogView';
|
||||
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { DialogView } from '../../dialogView';
|
||||
|
||||
/**
|
||||
* Dialog to render registered model views
|
||||
@@ -31,7 +31,13 @@ export class RegisteredModelsDialog extends ModelViewBase {
|
||||
|
||||
this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this);
|
||||
|
||||
let registerModelButton = this._apiWrapper.createButton(constants.registerModelTitle);
|
||||
registerModelButton.onClick(async () => {
|
||||
await this.sendDataRequest(RegisterModelEventName);
|
||||
});
|
||||
|
||||
let dialog = this.dialogView.createDialog('', [this.currentLanguagesTab]);
|
||||
dialog.customButtons = [registerModelButton];
|
||||
this.mainViewPanel = dialog;
|
||||
dialog.okButton.hidden = true;
|
||||
dialog.cancelButton.label = constants.extLangDoneButtonText;
|
||||
@@ -4,6 +4,8 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
import * as constants from '../common/constants';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as path from 'path';
|
||||
@@ -21,8 +23,7 @@ export interface CallbackEventArgs {
|
||||
}
|
||||
|
||||
export const CallEventNamePostfix = 'Callback';
|
||||
export const LocalFileEventName = 'localFile';
|
||||
export const LocalFolderEventName = 'localFolder';
|
||||
export const LocalPathsEventName = 'localPaths';
|
||||
|
||||
/**
|
||||
* Base class for views
|
||||
@@ -51,7 +52,7 @@ export abstract class ViewBase extends EventEmitterCollection {
|
||||
}
|
||||
|
||||
protected getEventNames(): string[] {
|
||||
return [LocalFolderEventName, LocalFileEventName];
|
||||
return [LocalPathsEventName];
|
||||
}
|
||||
|
||||
protected getCallbackEventNames(): string[] {
|
||||
@@ -118,12 +119,8 @@ export abstract class ViewBase extends EventEmitterCollection {
|
||||
});
|
||||
}
|
||||
|
||||
public async getLocalFilePath(): Promise<string> {
|
||||
return await this.sendDataRequest(LocalFileEventName);
|
||||
}
|
||||
|
||||
public async getLocalFolderPath(): Promise<string> {
|
||||
return await this.sendDataRequest(LocalFolderEventName);
|
||||
public async getLocalPaths(options: vscode.OpenDialogOptions): Promise<string[]> {
|
||||
return await this.sendDataRequest(LocalPathsEventName, options);
|
||||
}
|
||||
|
||||
public async getLocationTitle(): Promise<string> {
|
||||
@@ -174,12 +171,12 @@ export abstract class ViewBase extends EventEmitterCollection {
|
||||
}
|
||||
|
||||
public showErrorMessage(message: string, error?: any): void {
|
||||
this.showMessage(`${message} ${constants.getErrorMessage(error)}`, azdata.window.MessageLevel.Error);
|
||||
this.showMessage(`${message} ${error ? constants.getErrorMessage(error) : ''}`, azdata.window.MessageLevel.Error);
|
||||
}
|
||||
|
||||
private showMessage(message: string, level: azdata.window.MessageLevel): void {
|
||||
if (this._mainViewPanel) {
|
||||
this._mainViewPanel.message = {
|
||||
if (this.mainViewPanel) {
|
||||
this.mainViewPanel.message = {
|
||||
text: message,
|
||||
level: level
|
||||
};
|
||||
|
||||
@@ -45,6 +45,19 @@ export class WizardView extends MainViewBase {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds wizard page
|
||||
* @param page page
|
||||
* @param index page index
|
||||
*/
|
||||
public removeWizardPage(page: IPageView, index: number): void {
|
||||
if (this._wizard && this._pages[index] === page) {
|
||||
this._pages = this._pages.splice(index);
|
||||
this._wizard.removePage(index);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param title Creates anew wizard
|
||||
@@ -57,9 +70,21 @@ export class WizardView extends MainViewBase {
|
||||
this._wizard.onPageChanged(async (info) => {
|
||||
this.onWizardPageChanged(info);
|
||||
});
|
||||
|
||||
return this._wizard;
|
||||
}
|
||||
|
||||
public async validate(pageInfo: azdata.window.WizardPageChangeInfo): Promise<boolean> {
|
||||
if (pageInfo.lastPage !== undefined) {
|
||||
let idxLast = pageInfo.lastPage;
|
||||
let lastPage = this._pages[idxLast];
|
||||
if (lastPage && lastPage.validate) {
|
||||
return await lastPage.validate();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
|
||||
let idxLast = pageInfo.lastPage;
|
||||
let lastPage = this._pages[idxLast];
|
||||
@@ -73,4 +98,8 @@ export class WizardView extends MainViewBase {
|
||||
page.onEnter();
|
||||
}
|
||||
}
|
||||
|
||||
public get wizard(): azdata.window.Wizard | undefined {
|
||||
return this._wizard;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user