Machine Learning Services Extension - Predict wizard (#9450)

*MLS extension - Added predict wizard
This commit is contained in:
Leila Lali
2020-03-09 15:40:05 -07:00
committed by GitHub
parent b017634431
commit 3be3563b0d
37 changed files with 1501 additions and 219 deletions

View File

@@ -26,6 +26,7 @@
"modelManagement": {
"registeredModelsDatabaseName": "MlFlowDB",
"registeredModelsTableName": "artifacts",
"registeredModelsTableSchemaName": "dbo",
"amlModelManagementUrl": "modelmanagement.azureml.net",
"amlExperienceUrl": "experiments.azureml.net",
"amlApiVersion": "2018-11-19",

View File

@@ -57,6 +57,10 @@
"command": "mls.command.managePackages",
"title": "%mls.command.managePackages%"
},
{
"command": "mls.command.predictModel",
"title": "%mls.command.predictModel%"
},
{
"command": "mls.command.manageModels",
"title": "%mls.command.manageModels%"
@@ -110,7 +114,7 @@
"mls.command.managePackages",
"mls.command.manageLanguages",
"mls.command.manageModels",
"mls.command.registerModel"
"mls.command.predictModel"
]
}
},

View File

@@ -7,8 +7,9 @@
"title.endpoints": "Endpoints",
"mls.command.managePackages": "Manage Packages in SQL Server",
"mls.command.manageLanguages": "Manage External Languages",
"mls.command.manageModels": "Manage Models",
"mls.command.registerModel": "Register Model",
"mls.command.predictModel": "Make prediction",
"mls.command.manageModels": "Manage models",
"mls.command.registerModel": "Register model",
"mls.command.odbcdriver": "Install ODBC Driver for SQL Server",
"mls.command.mlsdocs": "Machine Learning Services Documentation",
"mls.configuration.title": "Machine Learning Services configurations",

View File

@@ -105,4 +105,28 @@ export class ApiWrapper {
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
return vscode.window.showQuickPick(items, options, token);
}
public listDatabases(connectionId: string): Thenable<string[]> {
return azdata.connection.listDatabases(connectionId);
}
public openTextDocument(options?: { language?: string; content?: string; }): Thenable<vscode.TextDocument> {
return vscode.workspace.openTextDocument(options);
}
public connect(fileUri: string, connectionId: string): Thenable<void> {
return azdata.queryeditor.connect(fileUri, connectionId);
}
public runQuery(fileUri: string, options?: Map<string, string>, runCurrentQuery?: boolean): void {
azdata.queryeditor.runQuery(fileUri, options, runCurrentQuery);
}
public showTextDocument(uri: vscode.Uri, options?: vscode.TextDocumentShowOptions): Thenable<vscode.TextEditor> {
return vscode.window.showTextDocument(uri, options);
}
public createButton(label: string, position?: azdata.window.DialogButtonPosition): azdata.window.Button {
return azdata.window.createButton(label, position);
}
}

View File

@@ -24,6 +24,7 @@ export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups';
// Tasks, commands
//
export const mlManageLanguagesCommand = 'mls.command.manageLanguages';
export const mlsPredictModelCommand = 'mls.command.predictModel';
export const mlManageModelsCommand = 'mls.command.manageModels';
export const mlRegisterModelCommand = 'mls.command.registerModel';
export const mlManagePackagesCommand = 'mls.command.managePackages';
@@ -116,6 +117,12 @@ export const modelCreated = localize('models.created', "Date Created");
export const modelVersion = localize('models.version', "Version");
export const browseModels = localize('models.browseButton', "...");
export const azureAccount = localize('models.azureAccount', "Azure account");
export const columnDatabase = localize('predict.columnDatabase', "Database");
export const columnTable = localize('predict.columnTable', "Table");
export const inputColumns = localize('predict.inputColumns', "Input columns");
export const outputColumns = localize('predict.outputColumns', "Output column");
export const columnName = localize('predict.columnName', "Name");
export const inputName = localize('predict.inputName', "Input Name");
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
export const azureGroup = localize('models.azureGroup', "Azure resource group");
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
@@ -125,18 +132,25 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Ender model source details");
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Select input columns");
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
export const azureRegisterModel = localize('models.azureRegisterModel', "Register");
export const predictModel = localize('models.predictModel', "Predict");
export const registerModelTitle = localize('models.RegisterWizard', "Register model");
export const makePredictionTitle = localize('models.makePredictionTitle', "Make prediction");
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
export const localModelSource = localize('models.localModelSource', "Upload file");
export const azureModelSource = localize('models.azureModelSource', "Import from AzureML registry");
export const registeredModelsSource = localize('models.registeredModelsSource', "Select managed models");
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict");
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");

View File

@@ -163,4 +163,21 @@ export class QueryRunner {
}
return result;
}
/**
* Executes the query but doesn't fail it is fails
* @param connection SQL connection
* @param query query to run
*/
public async safeRunQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this.runQuery(connection, query);
} catch (error) {
console.log(error);
return undefined;
}
}
}

View File

@@ -11,6 +11,7 @@ import * as fs from 'fs';
import * as constants from '../common/constants';
import { promisify } from 'util';
import { ApiWrapper } from './apiWrapper';
import { Config } from '../configurations/config';
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
let tempFilePath: string = '';
@@ -25,6 +26,11 @@ export async function execCommandOnTempFile<T>(content: string, command: (filePa
}
}
export async function readFileInHex(filePath: string): Promise<string> {
let buffer = await fs.promises.readFile(filePath);
return `0X${buffer.toString('hex')}`;
}
export async function exists(path: string): Promise<boolean> {
return promisify(fs.exists)(path);
}
@@ -109,8 +115,8 @@ export function isWindows(): boolean {
* ' => ''
* @param value The string to escape
*/
export function doubleEscapeSingleQuotes(value: string): string {
return value.replace(/'/g, '\'\'');
export function doubleEscapeSingleQuotes(value: string | undefined): string {
return value ? value.replace(/'/g, '\'\'') : '';
}
/**
@@ -118,8 +124,8 @@ export function doubleEscapeSingleQuotes(value: string): string {
* ' => ''
* @param value The string to escape
*/
export function doubleEscapeSingleBrackets(value: string): string {
return value.replace(/\[/g, '[[').replace(/\]/g, ']]');
export function doubleEscapeSingleBrackets(value: string | undefined): string {
return value ? value.replace(/\[/g, '[[').replace(/\]/g, ']]') : '';
}
/**
@@ -176,3 +182,48 @@ export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Pr
return choices[result.label] || false;
}
export function makeLinuxPath(filePath: string): string {
const parts = filePath.split('\\');
return parts.join('/');
}
/**
*
* @param currentDb Wraps the given script with database switch scripts
* @param databaseName
* @param script
*/
export function getScriptWithDBChange(currentDb: string, databaseName: string, script: string): string {
if (!currentDb) {
currentDb = 'master';
}
let escapedDbName = doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = doubleEscapeSingleBrackets(currentDb);
return `
USE [${escapedDbName}]
${script}
USE [${escapedCurrentDbName}]
`;
}
/**
* Returns full name of model registration table
* @param config config
*/
export function getRegisteredModelsThreePartsName(config: Config) {
const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName);
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
return `[${dbName}].${schema}.[${tableName}]`;
}
/**
* Returns full name of model registration table
* @param config config object
*/
export function getRegisteredModelsTowPartsName(config: Config) {
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
return `[${schema}].[${tableName}]`;
}

View File

@@ -82,6 +82,13 @@ export class Config {
return this._configValues.modelManagement.registeredModelsTableName;
}
/**
* Returns registered models table schema name
*/
public get registeredModelTableSchemaName(): string {
return this._configValues.modelManagement.registeredModelsTableSchemaName;
}
/**
* Returns registered models table name
*/

View File

@@ -22,6 +22,7 @@ import { ModelManagementController } from '../views/models/modelManagementContro
import { RegisteredModelService } from '../modelManagement/registeredModelService';
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
import { ModelImporter } from '../modelManagement/modelImporter';
import { PredictService } from '../prediction/predictService';
/**
* The main controller class that initializes the extension
@@ -109,7 +110,9 @@ export default class MainController implements vscode.Disposable {
//
let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, azureModelsService, registeredModelService);
let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config);
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
azureModelsService, registeredModelService, predictService);
this._apiWrapper.registerCommand(constants.mlManageLanguagesCommand, (async () => {
await languageController.manageLanguages();
@@ -120,6 +123,9 @@ export default class MainController implements vscode.Disposable {
this._apiWrapper.registerCommand(constants.mlRegisterModelCommand, (async () => {
await modelManagementController.registerModel();
}));
this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => {
await modelManagementController.predictModel();
}));
this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => {
await packageManager.installDependencies();
}));
@@ -135,6 +141,9 @@ export default class MainController implements vscode.Disposable {
this._apiWrapper.registerTaskHandler(constants.mlRegisterModelCommand, async () => {
await modelManagementController.registerModel();
});
this._apiWrapper.registerTaskHandler(constants.mlsPredictModelCommand, async () => {
await modelManagementController.predictModel();
});
this._apiWrapper.registerTaskHandler(constants.mlOdbcDriverCommand, async () => {
await this.serverConfigManager.openOdbcDriverDocuments();
});

View File

@@ -48,13 +48,19 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
/**
* An interface representing registered model
*/
export interface RegisteredModel {
id?: number,
artifactName?: string,
title?: string,
created?: string,
version?: string
description?: string
export interface RegisteredModel extends RegisteredModelDetails {
id: number;
artifactName: string;
}
/**
* An interface representing registered model
*/
export interface RegisteredModelDetails {
title: string;
created?: string;
version?: string;
description?: string;
}
/**

View File

@@ -12,6 +12,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as utils from '../common/utils';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
import * as os from 'os';
/**
* Service to import model to database
@@ -39,8 +40,8 @@ export class ModelImporter {
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
const parts = modelFolderPath.split('\\');
modelFolderPath = parts.join('/');
let home = utils.makeLinuxPath(os.homedir());
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
@@ -51,9 +52,12 @@ export class ModelImporter {
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
let scripts: string[] = [
'import mlflow.onnx',
`tracking_uri = "file://${home}/mlruns"`,
'print(tracking_uri)',
'import onnx',
'from mlflow.tracking.client import MlflowClient',
`onx = onnx.load("${modelFolderPath}")`,
`mlflow.set_tracking_uri(tracking_uri)`,
'client = MlflowClient()',
`exp_name = "${experimentId}"`,
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,

View File

@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
import { RegisteredModel } from './interfaces';
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
import { ModelImporter } from './modelImporter';
import * as constants from '../common/constants';
@@ -32,7 +32,10 @@ export class RegisteredModelService {
let connection = await this.getCurrentConnection();
let list: RegisteredModel[] = [];
if (connection) {
let result = await this.runRegisteredModelsListQuery(connection);
let query = this.getConfigureQuery(connection.databaseName);
await this._queryRunner.safeRunQuery(connection, query);
query = this.registeredModelsQuery();
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push(this.loadModelData(row));
@@ -57,7 +60,8 @@ export class RegisteredModelService {
let connection = await this.getCurrentConnection();
let updatedModel: RegisteredModel | undefined = undefined;
if (connection) {
let result = await this.runUpdateModelQuery(connection, model);
const query = this.getUpdateModelScript(connection.databaseName, model);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
const row = result.rows[0];
updatedModel = this.loadModelData(row);
@@ -66,7 +70,7 @@ export class RegisteredModelService {
return updatedModel;
}
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
let connection = await this.getCurrentConnection();
if (connection) {
let currentModels = await this.getRegisteredModels();
@@ -93,35 +97,14 @@ export class RegisteredModelService {
return await this._apiWrapper.getCurrentConnection();
}
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
} catch {
return undefined;
}
private getConfigureQuery(currentDatabaseName: string): string {
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
}
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
} catch {
return undefined;
}
}
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
private registeredModelsQuery(): string {
return `
${this.configureTable(databaseName, tableName)}
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created
FROM [${escapedDbName}].dbo.[${escapedTableName}]
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
@@ -133,52 +116,74 @@ export class RegisteredModelService {
* @param databaseName
* @param tableName
*/
private configureTable(databaseName: string, tableName: string): string {
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
private configureTable(): string {
let databaseName = this._config.registeredModelDatabaseName;
let tableName = this._config.registeredModelTableName;
let schemaName = this._config.registeredModelTableSchemaName;
return `
USE [${escapedDbName}]
IF NOT EXISTS (
SELECT [name]
FROM sys.databases
WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}'
)
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}]
GO
USE [${utils.doubleEscapeSingleBrackets(databaseName)}]
IF EXISTS
( SELECT [name]
FROM sys.tables
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
( SELECT [t.name], [s.name]
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name')
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version')
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created')
BEGIN
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
END
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description')
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL
END
Else
BEGIN
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
[artifact_name] [varchar](256) NOT NULL,
[group_path] [varchar](256) NOT NULL,
[artifact_content] [varbinary](max) NOT NULL,
[artifact_initial_size] [bigint] NULL,
[name] [varchar](256) NULL,
[version] [varchar](256) NULL,
[created] [datetime] NULL,
[description] [varchar](256) NULL,
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
(
[artifact_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
END
`;
}
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
let updateScript = `
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
SET
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
WHERE artifact_id = ${model.id}`;
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
return `
USE [${escapedDbName}]
UPDATE ${escapedTableName}
SET
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
WHERE artifact_id = ${model.id};
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
SELECT artifact_id, artifact_name, name, description, version, created
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
WHERE artifact_id = ${model.id};
`;
}

View File

@@ -0,0 +1,24 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
export interface PredictColumn {
name: string;
dataType?: string;
displayName?: string;
}
export interface DatabaseTable {
databaseName: string | undefined;
tableName: string | undefined;
schema: string | undefined
}
export interface PredictInputParameters extends DatabaseTable {
inputColumns: PredictColumn[] | undefined
}
export interface PredictParameters extends PredictInputParameters {
outputColumns: PredictColumn[] | undefined
}

View File

@@ -0,0 +1,203 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import * as utils from '../common/utils';
import { RegisteredModel } from '../modelManagement/interfaces';
import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces';
import { Config } from '../configurations/config';
/**
* Service to make prediction
*/
export class PredictService {
/**
* Creates new instance
*/
constructor(
private _apiWrapper: ApiWrapper,
private _queryRunner: QueryRunner,
private _config: Config) {
}
/**
* Returns the list of databases
*/
public async getDatabaseList(): Promise<string[]> {
let connection = await this.getCurrentConnection();
if (connection) {
return await this._apiWrapper.listDatabases(connection.connectionId);
}
return [];
}
/**
* Generates prediction script given model info and predict parameters
* @param predictParams predict parameters
* @param registeredModel model parameters
*/
public async generatePredictScript(
predictParams: PredictParameters,
registeredModel: RegisteredModel | undefined,
filePath: string | undefined
): Promise<string> {
let connection = await this.getCurrentConnection();
let query = '';
if (registeredModel && registeredModel.id) {
query = this.getPredictScriptWithModelId(
registeredModel.id,
predictParams.inputColumns || [],
predictParams.outputColumns || [],
predictParams);
} else if (filePath) {
let modelBytes = await utils.readFileInHex(filePath || '');
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
predictParams.outputColumns || [],
predictParams);
}
let document = await this._apiWrapper.openTextDocument({
language: 'sql',
content: query
});
await this._apiWrapper.showTextDocument(document.uri);
await this._apiWrapper.connect(document.uri.toString(), connection.connectionId);
this._apiWrapper.runQuery(document.uri.toString(), undefined, false);
return query;
}
/**
* Returns list of tables given database name
* @param databaseName database name
*/
public async getTableList(databaseName: string): Promise<DatabaseTable[]> {
let connection = await this.getCurrentConnection();
let list: DatabaseTable[] = [];
if (connection) {
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push({
databaseName: databaseName,
tableName: row[0].displayValue,
schema: row[1].displayValue
});
});
}
}
return list;
}
/**
*Returns list of column names of a database
* @param databaseTable table info
*/
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<string[]> {
let connection = await this.getCurrentConnection();
let list: string[] = [];
if (connection && databaseTable.databaseName) {
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push(row[0].displayValue);
});
}
}
return list;
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
private getTableColumnsScript(databaseTable: DatabaseTable): string {
return `
SELECT COLUMN_NAME,*
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
`;
}
private getTablesScript(databaseName: string): string {
return `
SELECT TABLE_NAME,TABLE_SCHEMA
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
`;
}
private getPredictScriptWithModelId(
modelId: number,
columns: PredictColumn[],
outputColumns: PredictColumn[],
databaseNameTable: DatabaseTable): string {
return `
DECLARE @model VARBINARY(max) = (
SELECT artifact_content
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
WHERE artifact_id = ${modelId}
);
WITH predict_input
AS (
SELECT TOP 1000
${this.getColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = @model, DATA = predict_input)
WITH (
${this.getColumnTypes(outputColumns)}
) AS p
`;
}
private getPredictScriptWithModelBytes(
modelBytes: string,
columns: PredictColumn[],
outputColumns: PredictColumn[],
databaseNameTable: DatabaseTable): string {
return `
WITH predict_input
AS (
SELECT TOP 1000
${this.getColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
WITH (
${this.getColumnTypes(outputColumns)}
) AS p
`;
}
private getColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`;
}).join(',\n');
}
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`;
}).join(',\n');
}
private getColumnTypes(columns: PredictColumn[]) {
return columns.map(c => {
return `${c.name} ${c.dataType}`;
}).join(',\n');
}
}

View File

@@ -13,7 +13,7 @@ import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { RegisterModelWizard } from '../../../views/models/registerModelWizard';
import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard';
describe('Register Model Wizard', () => {
it('Should create view components successfully ', async function (): Promise<void> {
@@ -74,7 +74,8 @@ describe('Register Model Wizard', () => {
let localModels: RegisteredModel[] = [
{
id: 1,
artifactName: 'model'
artifactName: 'model',
title: 'model'
}
];
view.on(ListModelsEventName, () => {

View File

@@ -6,7 +6,7 @@
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { RegisteredModelsDialog } from '../../../views/models/registeredModelsDialog';
import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog';
import { ListModelsEventName } from '../../../views/models/modelViewBase';
import { RegisteredModel } from '../../../modelManagement/interfaces';
import { ViewBase } from '../../../views/viewBase';
@@ -30,7 +30,8 @@ describe('Registered Models Dialog', () => {
let models: RegisteredModel[] = [
{
id: 1,
artifactName: 'model'
artifactName: 'model',
title: ''
}
];
view.on(ListModelsEventName, () => {

View File

@@ -246,6 +246,7 @@ export function createViewContext(): ViewTestContext {
modelView: undefined!,
valid: true
};
apiWrapper.setup(x => x.createButton(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => dialogButton);
apiWrapper.setup(x => x.createTab(TypeMoq.It.isAny())).returns(() => tab);
apiWrapper.setup(x => x.createWizard(TypeMoq.It.isAny())).returns(() => wizard);
apiWrapper.setup(x => x.createWizardPage(TypeMoq.It.isAny())).returns(() => wizardPage);

View File

@@ -3,7 +3,9 @@
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ViewBase, LocalFileEventName, LocalFolderEventName } from './viewBase';
import * as vscode from 'vscode';
import { ViewBase, LocalPathsEventName } from './viewBase';
import { ApiWrapper } from '../common/apiWrapper';
/**
@@ -36,11 +38,8 @@ export abstract class ControllerBase {
* @param view view
*/
public registerEvents(view: ViewBase): void {
view.on(LocalFileEventName, async () => {
await this.executeAction(view, LocalFileEventName, this.getLocalFilePath, this._apiWrapper);
});
view.on(LocalFolderEventName, async () => {
await this.executeAction(view, LocalFolderEventName, this.getLocalFolderPath, this._apiWrapper);
view.on(LocalPathsEventName, async (args) => {
await this.executeAction(view, LocalPathsEventName, this.getLocalPaths, this._apiWrapper, args);
});
}
@@ -48,25 +47,8 @@ export abstract class ControllerBase {
* Returns local file path picked by the user
* @param apiWrapper apiWrapper
*/
public async getLocalFilePath(apiWrapper: ApiWrapper): Promise<string> {
let result = await apiWrapper.showOpenDialog({
canSelectFiles: true,
canSelectFolders: false,
canSelectMany: false
});
return result && result.length > 0 ? result[0].fsPath : '';
}
/**
* Returns local folder path picked by the user
* @param apiWrapper apiWrapper
*/
public async getLocalFolderPath(apiWrapper: ApiWrapper): Promise<string> {
let result = await apiWrapper.showOpenDialog({
canSelectFiles: false,
canSelectFolders: true,
canSelectMany: false
});
return result && result.length > 0 ? result[0].fsPath : '';
public async getLocalPaths(apiWrapper: ApiWrapper, options: vscode.OpenDialogOptions): Promise<string[]> {
let result = await apiWrapper.showOpenDialog(options);
return result ? result?.map(x => x.fsPath) : [];
}
}

View File

@@ -16,6 +16,7 @@ export interface IPageView {
component: azdata.Component | undefined;
onEnter?: () => Promise<void>;
onLeave?: () => Promise<void>;
validate?: () => Promise<boolean>;
refresh: () => Promise<void>;
viewPanel: azdata.window.ModelViewPanel | undefined;
title: string;
@@ -32,3 +33,4 @@ export interface AzureModelResource extends AzureWorkspaceResource {
model?: WorkspaceModel;
}

View File

@@ -4,6 +4,8 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
@@ -43,9 +45,17 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen
}
}).component();
this._localBrowse.onDidClick(async () => {
const filePath = await this.getLocalFilePath();
let options: vscode.OpenDialogOptions = {
canSelectFiles: true,
canSelectFolders: false,
canSelectMany: false,
filters: { 'ONNX File': ['onnx'] }
};
const filePaths = await this.getLocalPaths(options);
if (this._localPath) {
this._localPath.value = filePath;
this._localPath.value = filePaths && filePaths.length > 0 ? filePaths[0] : '';
}
});

View File

@@ -8,12 +8,12 @@ import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
import { RegisteredModel } from '../../modelManagement/interfaces';
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
/**
* View to pick local models file
*/
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModel> {
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModelDetails> {
private _form: azdata.FormContainer | undefined;
private _nameComponent: azdata.InputBoxComponent | undefined;
@@ -81,9 +81,9 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone
/**
* Returns selected data
*/
public get data(): RegisteredModel {
public get data(): RegisteredModelDetails {
return {
title: this._nameComponent?.value,
title: this._nameComponent?.value || '',
description: this._descriptionComponent?.value
};
}

View File

@@ -9,12 +9,12 @@ import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { ModelDetailsComponent } from './modelDetailsComponent';
import { RegisteredModel } from '../../modelManagement/interfaces';
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
/**
* View to pick model details
*/
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModel> {
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModelDetails> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
@@ -43,7 +43,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
/**
* Returns selected data
*/
public get data(): RegisteredModel | undefined {
public get data(): RegisteredModelDetails | undefined {
return this.modelDetails?.data;
}
@@ -66,4 +66,13 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
public get title(): string {
return constants.modelDetailsPageTitle;
}
public validate(): Promise<boolean> {
if (this.data && this.data.title) {
return Promise.resolve(true);
} else {
this.showErrorMessage(constants.modelNameRequiredError);
return Promise.resolve(false);
}
}
}

View File

@@ -9,14 +9,23 @@ import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
import { RegisteredModelsDialog } from './registeredModelsDialog';
import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName } from './modelViewBase';
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { RegisterModelWizard } from './registerModelWizard';
import { RegisterModelWizard } from './registerModels/registerModelWizard';
import * as fs from 'fs';
import * as constants from '../../common/constants';
import { PredictWizard } from './prediction/predictWizard';
import { AzureModelResource } from '../interfaces';
import { PredictService } from '../../prediction/predictService';
/**
* Model management UI controller
@@ -30,7 +39,8 @@ export class ModelManagementController extends ControllerBase {
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
private _registeredModelService: RegisteredModelService) {
private _registeredModelService: RegisteredModelService,
private _predictService: PredictService) {
super(apiWrapper);
}
@@ -56,6 +66,23 @@ export class ModelManagementController extends ControllerBase {
return view;
}
/**
* Opens the wizard for prediction
*/
public async predictModel(): Promise<ModelViewBase> {
let view = new PredictWizard(this._apiWrapper, this._root);
this.registerEvents(view);
// Open view
//
view.open();
await view.refresh();
return view;
}
/**
* Register events in the main view
* @param view main view
@@ -102,6 +129,28 @@ export class ModelManagementController extends ControllerBase {
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
});
view.on(DownloadAzureModelEventName, async (arg) => {
let registerArgs = <AzureModelResource>arg;
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
});
view.on(ListDatabaseNamesEventName, async () => {
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
});
view.on(ListTableNamesEventName, async (arg) => {
let dbName = <string>arg;
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
});
view.on(ListColumnNamesEventName, async (arg) => {
let tableColumnsArgs = <DatabaseTable>arg;
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
tableColumnsArgs);
});
view.on(PredictModelEventName, async (arg) => {
let predictArgs = <PredictModelEventArgs>arg;
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
predictArgs, predictArgs.model, predictArgs.filePath);
});
view.on(SourceModelSelectedEventName, () => {
view.refresh();
});
@@ -158,7 +207,7 @@ export class ModelManagementController extends ControllerBase {
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise<void> {
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
if (filePath) {
await service.registerLocalModel(filePath, details);
} else {
@@ -175,7 +224,7 @@ export class ModelManagementController extends ControllerBase {
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined,
details: RegisteredModel | undefined): Promise<void> {
details: RegisteredModelDetails | undefined): Promise<void> {
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
throw Error(constants.invalidAzureResourceError);
}
@@ -188,4 +237,47 @@ export class ModelManagementController extends ControllerBase {
throw Error(constants.invalidModelToRegisterError);
}
}
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
return await predictService.getDatabaseList();
}
public async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
return await predictService.getTableList(databaseName);
}
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
return await predictService.getTableColumnsList(databaseTable);
}
private async generatePredictScript(
predictService: PredictService,
predictParams: PredictParameters,
registeredModel: RegisteredModel | undefined,
filePath: string | undefined
): Promise<string> {
if (!predictParams) {
throw Error(constants.invalidModelToPredictError);
}
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
return result;
}
private async downloadAzureModel(
azureService: AzureModelRegistryService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined): Promise<string> {
if (!account || !subscription || !resourceGroup || !workspace || !model) {
throw Error(constants.invalidAzureResourceError);
}
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
return filePath;
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
}

View File

@@ -11,6 +11,7 @@ import { IPageView, IDataComponent } from '../interfaces';
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import { CurrentModelsTable } from './registerModels/currentModelsTable';
/**
* View to pick model source
@@ -22,8 +23,9 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
public modelResources: ModelSourcesComponent | undefined;
public localModelsComponent: LocalModelsComponent | undefined;
public azureModelsComponent: AzureModelsComponent | undefined;
public registeredModelsComponent: CurrentModelsTable | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
}
@@ -34,13 +36,15 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
this.modelResources.registerComponent(modelBuilder);
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
this.localModelsComponent.registerComponent(modelBuilder);
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
this.azureModelsComponent.registerComponent(modelBuilder);
this.modelResources.addComponents(this._formBuilder);
this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this);
this.registeredModelsComponent.registerComponent(modelBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
@@ -66,19 +70,29 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.modelResources && this.modelResources.data === ModelSourceType.Local) {
if (this.localModelsComponent && this.azureModelsComponent) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.azureModelsComponent.removeComponents(this._formBuilder);
this.registeredModelsComponent.removeComponents(this._formBuilder);
this.localModelsComponent.addComponents(this._formBuilder);
await this.localModelsComponent.refresh();
}
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) {
if (this.localModelsComponent && this.azureModelsComponent) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.localModelsComponent.removeComponents(this._formBuilder);
this.azureModelsComponent.addComponents(this._formBuilder);
this.registeredModelsComponent.removeComponents(this._formBuilder);
await this.azureModelsComponent.refresh();
}
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.localModelsComponent.removeComponents(this._formBuilder);
this.azureModelsComponent.removeComponents(this._formBuilder);
this.registeredModelsComponent.addComponents(this._formBuilder);
await this.registeredModelsComponent.refresh();
}
}
}
}
@@ -89,4 +103,21 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
public get title(): string {
return constants.modelSourcePageTitle;
}
public validate(): Promise<boolean> {
let validated = false;
if (this.modelResources && this.modelResources.data === ModelSourceType.Local && this.localModelsComponent) {
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure && this.azureModelsComponent) {
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.model !== undefined;
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
validated = this.registeredModelsComponent.data !== undefined;
}
if (!validated) {
this.showErrorMessage(constants.invalidModelToSelectError);
}
return Promise.resolve(validated);
}
}

View File

@@ -11,7 +11,8 @@ import { IDataComponent } from '../interfaces';
export enum ModelSourceType {
Local,
Azure
Azure,
RegisteredModels
}
/**
* View to pick model source
@@ -22,9 +23,10 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
private _flexContainer: azdata.FlexContainer | undefined;
private _amlModel: azdata.RadioButtonComponent | undefined;
private _localModel: azdata.RadioButtonComponent | undefined;
private _isLocalModel: boolean = true;
private _registeredModels: azdata.RadioButtonComponent | undefined;
private _sourceType: ModelSourceType = ModelSourceType.Local;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
}
@@ -38,7 +40,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
value: 'local',
name: 'modelLocation',
label: constants.localModelSource,
checked: true
checked: this._options[0] === ModelSourceType.Local
}).component();
@@ -47,26 +49,58 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
value: 'aml',
name: 'modelLocation',
label: constants.azureModelSource,
checked: this._options[0] === ModelSourceType.Azure
}).component();
this._registeredModels = modelBuilder.radioButton()
.withProperties({
value: 'registered',
name: 'modelLocation',
label: constants.registeredModelsSource,
checked: this._options[0] === ModelSourceType.RegisteredModels
}).component();
this._localModel.onDidClick(() => {
this._isLocalModel = true;
this._sourceType = ModelSourceType.Local;
this.sendRequest(SourceModelSelectedEventName);
});
this._amlModel.onDidClick(() => {
this._isLocalModel = false;
this._sourceType = ModelSourceType.Azure;
this.sendRequest(SourceModelSelectedEventName);
});
this._registeredModels.onDidClick(() => {
this._sourceType = ModelSourceType.RegisteredModels;
this.sendRequest(SourceModelSelectedEventName);
});
let components: azdata.RadioButtonComponent[] = [];
this._options.forEach(option => {
switch (option) {
case ModelSourceType.Local:
if (this._localModel) {
components.push(this._localModel);
}
break;
case ModelSourceType.Azure:
if (this._amlModel) {
components.push(this._amlModel);
}
break;
case ModelSourceType.RegisteredModels:
if (this._registeredModels) {
components.push(this._registeredModels);
}
break;
}
});
this._sourceType = this._options[0];
this._flexContainer = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'column',
justifyContent: 'space-between'
}).withItems([
this._localModel, this._amlModel]
).component();
}).withItems(components).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
@@ -92,7 +126,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
* Returns selected data
*/
public get data(): ModelSourceType {
return this._isLocalModel ? ModelSourceType.Local : ModelSourceType.Azure;
return this._sourceType;
}
/**

View File

@@ -8,7 +8,8 @@ import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { ViewBase } from '../viewBase';
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
@@ -16,13 +17,18 @@ export interface AzureResourceEventArgs extends AzureWorkspaceResource {
}
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
details?: RegisteredModel
details?: RegisteredModelDetails
}
export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs {
model?: WorkspaceModel;
}
export interface PredictModelEventArgs extends PredictParameters {
model?: RegisteredModel;
filePath?: string;
}
export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
filePath?: string;
}
@@ -32,11 +38,16 @@ export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
export const ListModelsEventName = 'listModels';
export const ListAzureModelsEventName = 'listAzureModels';
export const ListAccountsEventName = 'listAccounts';
export const ListDatabaseNamesEventName = 'listDatabaseNames';
export const ListTableNamesEventName = 'listTableNames';
export const ListColumnNamesEventName = 'listColumnNames';
export const ListSubscriptionsEventName = 'listSubscriptions';
export const ListGroupsEventName = 'listGroups';
export const ListWorkspacesEventName = 'listWorkspaces';
export const RegisterLocalModelEventName = 'registerLocalModel';
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
export const PredictModelEventName = 'predictModel';
export const RegisterModelEventName = 'registerModel';
export const SourceModelSelectedEventName = 'sourceModelSelected';
@@ -59,7 +70,12 @@ export abstract class ModelViewBase extends ViewBase {
RegisterLocalModelEventName,
RegisterAzureModelEventName,
RegisterModelEventName,
SourceModelSelectedEventName]);
SourceModelSelectedEventName,
ListDatabaseNamesEventName,
ListTableNamesEventName,
ListColumnNamesEventName,
PredictModelEventName,
DownloadAzureModelEventName]);
}
/**
@@ -91,6 +107,27 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(ListAccountsEventName);
}
/**
* lists database names
*/
public async listDatabaseNames(): Promise<string[]> {
return await this.sendDataRequest(ListDatabaseNamesEventName);
}
/**
* lists table names
*/
public async listTableNames(dbName: string): Promise<DatabaseTable[]> {
return await this.sendDataRequest(ListTableNamesEventName, dbName);
}
/**
* lists column names
*/
public async listColumnNames(table: DatabaseTable): Promise<string[]> {
return await this.sendDataRequest(ListColumnNamesEventName, table);
}
/**
* lists azure subscriptions
* @param account azure account
@@ -106,7 +143,7 @@ export abstract class ModelViewBase extends ViewBase {
* registers local model
* @param localFilePath local file path
*/
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModel | undefined): Promise<void> {
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
const args: RegisterLocalModelEventArgs = {
filePath: localFilePath,
details: details
@@ -114,17 +151,38 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(RegisterLocalModelEventName, args);
}
/**
* download azure model
* @param args azure resource
*/
public async downloadAzureModel(resource: AzureModelResource | undefined): Promise<string> {
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
}
/**
* registers azure model
* @param args azure resource
*/
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModel | undefined): Promise<void> {
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
const args: RegisterAzureModelEventArgs = Object.assign({}, resource, {
details: details
});
return await this.sendDataRequest(RegisterAzureModelEventName, args);
}
/**
* registers azure model
* @param args azure resource
*/
public async generatePredictScript(model: RegisteredModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> {
const args: PredictModelEventArgs = Object.assign({}, params, {
model: model,
filePath: filePath,
loadFromRegisteredModel: !filePath
});
return await this.sendDataRequest(PredictModelEventName, args);
}
/**
* list resource groups
* @param account azure account

View File

@@ -0,0 +1,168 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { ColumnsTable } from './columnsTable';
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
/**
* View to render filters to pick an azure resource
*/
export class ColumnsFilterComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
private _form: azdata.FormContainer | undefined;
private _databases: azdata.DropDownComponent | undefined;
private _tables: azdata.DropDownComponent | undefined;
private _columns: ColumnsTable | undefined;
private _dbNames: string[] = [];
private _tableNames: DatabaseTable[] = [];
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._databases = modelBuilder.dropDown().withProperties({
width: this.componentMaxLength
}).component();
this._tables = modelBuilder.dropDown().withProperties({
width: this.componentMaxLength
}).component();
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
this._databases.onValueChanged(async () => {
await this.onDatabaseSelected();
});
this._tables.onValueChanged(async () => {
await this.onTableSelected();
});
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._databases
}, {
title: constants.azureSubscription,
component: this._tables
}, {
title: constants.azureGroup,
component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._databases && this._tables && this._columns) {
formBuilder.addFormItems([{
title: constants.columnDatabase,
component: this._databases
}, {
title: constants.columnTable,
component: this._tables
}, {
title: constants.inputColumns,
component: this._columns.component
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._databases && this._tables && this._columns) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._databases
});
formBuilder.removeFormItem({
title: constants.azureSubscription,
component: this._tables
});
formBuilder.removeFormItem({
title: constants.azureGroup,
component: this._columns.component
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Returns selected data
*/
public get data(): PredictInputParameters | undefined {
return Object.assign({}, this.databaseTable, {
inputColumns: this.columnNames
});
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
this._dbNames = await this.listDatabaseNames();
if (this._databases && this._dbNames && this._dbNames.length > 0) {
this._databases.values = this._dbNames;
this._databases.value = this._dbNames[0];
}
await this.onDatabaseSelected();
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
private async onDatabaseSelected(): Promise<void> {
this._tableNames = await this.listTableNames(this.databaseName || '');
if (this._tables && this._tableNames && this._tableNames.length > 0) {
this._tables.values = this._tableNames.map(t => this.getTableFullName(t));
this._tables.value = this.getTableFullName(this._tableNames[0]);
}
await this.onTableSelected();
}
private getTableFullName(table: DatabaseTable): string {
return `${table.schema}.${table.tableName}`;
}
private async onTableSelected(): Promise<void> {
this._columns?.loadData(this.databaseTable);
}
private get databaseName(): string | undefined {
return <string>this._databases?.value;
}
private get databaseTable(): DatabaseTable {
let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._tables?.value);
return {
databaseName: this.databaseName,
tableName: selectedItem?.tableName,
schema: selectedItem?.schema
};
}
private get columnNames(): PredictColumn[] | undefined {
return this._columns?.data;
}
}

View File

@@ -0,0 +1,84 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { ColumnsFilterComponent } from './columnsFilterComponent';
import { OutputColumnsComponent } from './outputColumnsComponent';
import { PredictParameters } from '../../../prediction/interfaces';
/**
* View to pick model source
*/
export class ColumnsSelectionPage extends ModelViewBase implements IPageView, IDataComponent<PredictParameters> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public columnsFilterComponent: ColumnsFilterComponent | undefined;
public outputColumnsComponent: OutputColumnsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this);
this.columnsFilterComponent.registerComponent(modelBuilder);
this.columnsFilterComponent.addComponents(this._formBuilder);
this.refresh();
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
this.outputColumnsComponent.registerComponent(modelBuilder);
this.outputColumnsComponent.addComponents(this._formBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): PredictParameters | undefined {
return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ?
Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
undefined;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.columnsFilterComponent) {
await this.columnsFilterComponent.refresh();
}
if (this.outputColumnsComponent) {
await this.outputColumnsComponent.refresh();
}
}
}
/**
* Returns page title
*/
public get title(): string {
return constants.columnSelectionPageTitle;
}
}

View File

@@ -0,0 +1,155 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IDataComponent } from '../../interfaces';
import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces';
/**
* View to render azure models in a table
*/
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
private _table: azdata.DeclarativeTableComponent;
private _selectedColumns: PredictColumn[] = [];
private _columns: string[] | undefined;
/**
* Creates a view to render azure models in a table
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._table = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Name
displayName: constants.columnDatabase,
ariaLabel: constants.columnName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 120,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.inputName,
ariaLabel: constants.inputName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
}
public get component(): azdata.DeclarativeTableComponent {
return this._table;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadData(table: DatabaseTable): Promise<void> {
this._selectedColumns = [];
if (this._table) {
this._columns = await this.listColumnNames(table);
let tableData: any[][] = [];
if (this._columns) {
tableData = tableData.concat(this._columns.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
}
private createTableRow(column: string): any[] {
if (this._modelBuilder) {
let selectRowButton = this._modelBuilder.checkBox().withProperties({
width: 15,
height: 15,
checked: true
}).component();
let nameInputBox = this._modelBuilder.inputBox().withProperties({
value: '',
width: 150
}).component();
this._selectedColumns.push({ name: column });
selectRowButton.onChanged(() => {
if (selectRowButton.checked) {
if (!this._selectedColumns.find(x => x.name === column)) {
this._selectedColumns.push({ name: column });
}
} else {
if (this._selectedColumns.find(x => x.name === column)) {
this._selectedColumns = this._selectedColumns.filter(x => x.name !== column);
}
}
});
nameInputBox.onTextChanged(() => {
let selectedRow = this._selectedColumns.find(x => x.name === column);
if (selectedRow) {
selectedRow.displayName = nameInputBox.value;
}
});
return [column, nameInputBox, selectRowButton];
}
return [];
}
/**
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
return this._selectedColumns;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,124 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { PredictColumn } from '../../../prediction/interfaces';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 60;
export class OutputColumnsComponent extends ModelViewBase implements IDataComponent<PredictColumn[]> {
private _form: azdata.FormContainer | undefined;
private _flex: azdata.FlexContainer | undefined;
private _columnName: azdata.InputBoxComponent | undefined;
private _columnTypes: azdata.DropDownComponent | undefined;
private _dataTypes: string[] = [
'int',
'nvarchar(MAX)',
'varchar(MAX)',
'float',
'double',
'bit'
];
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._columnName = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength
}).component();
this._columnTypes = modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
let flex = modelBuilder.flexContainer()
.withLayout({
width: this._columnName.width
}).withItems([
this._columnName]
).component();
this._flex = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between',
width: this.componentMaxLength
}).withItems([
flex, this._columnTypes]
).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._flex
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.addFormItems([{
title: constants.outputColumns,
component: this._flex
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.removeFormItem({
title: constants.outputColumns,
component: this._flex
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
if (this._columnTypes) {
this._columnTypes.values = this._dataTypes;
this._columnTypes.value = this._dataTypes[0];
}
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
/**
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
return this._columnName && this._columnTypes ? [{
name: this._columnName.value || '',
dataType: <string>this._columnTypes.value || ''
}] : undefined;
}
}

View File

@@ -0,0 +1,111 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
import { AzureModelsComponent } from '../azureModelsComponent';
import * as constants from '../../../common/constants';
import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage';
import { ColumnsSelectionPage } from './columnsSelectionPage';
import { RegisteredModel } from '../../../modelManagement/interfaces';
/**
* Wizard to register a model
*/
export class PredictWizard extends ModelViewBase {
public modelSourcePage: ModelSourcePage | undefined;
//public modelDetailsPage: ModelDetailsPage | undefined;
public columnsSelectionPage: ColumnsSelectionPage | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
constructor(
apiWrapper: ApiWrapper,
root: string,
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
}
/**
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
[this.modelSourcePage,
this.columnsSelectionPage]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.predictModel;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
if (validated && pageInfo.newPage === undefined) {
wizard.cancelButton.enabled = false;
wizard.backButton.enabled = false;
await this.predict();
wizard.cancelButton.enabled = true;
wizard.backButton.enabled = true;
if (this._parentView) {
this._parentView?.refresh();
}
return true;
}
return validated;
});
wizard.open();
}
public get modelResources(): ModelSourcesComponent | undefined {
return this.modelSourcePage?.modelResources;
}
public get localModelsComponent(): LocalModelsComponent | undefined {
return this.modelSourcePage?.localModelsComponent;
}
public get azureModelsComponent(): AzureModelsComponent | undefined {
return this.modelSourcePage?.azureModelsComponent;
}
private async predict(): Promise<boolean> {
try {
let modelFilePath: string = '';
let registeredModel: RegisteredModel | undefined = undefined;
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
modelFilePath = this.localModelsComponent.data;
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data);
} else {
registeredModel = this.modelSourcePage?.registeredModelsComponent?.data;
}
await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data);
return true;
} catch (error) {
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
return false;
}
}
/**
* Refresh the pages
*/
public async refresh(): Promise<void> {
await this.wizardView?.refresh();
}
}

View File

@@ -5,11 +5,11 @@
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import { ModelViewBase, RegisterModelEventName } from './modelViewBase';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { CurrentModelsTable } from './currentModelsTable';
import { ApiWrapper } from '../../common/apiWrapper';
import { IPageView } from '../interfaces';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IPageView } from '../../interfaces';
/**
* View to render current registered models
@@ -33,28 +33,21 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._dataTable = new CurrentModelsTable(this._apiWrapper, modelBuilder, this);
this._dataTable = new CurrentModelsTable(this._apiWrapper, this);
this._dataTable.registerComponent(modelBuilder);
this._tableComponent = this._dataTable.component;
let registerButton = modelBuilder.button().withProperties({
label: constants.registerModelTitle,
width: this.buttonMaxLength
}).component();
registerButton.onDidClick(async () => {
await this.sendDataRequest(RegisterModelEventName);
});
let formModelBuilder = modelBuilder.formContainer();
let formModel = modelBuilder.formContainer()
.withFormItems([{
title: '',
component: registerButton
}, {
if (this._tableComponent) {
formModelBuilder.addFormItem({
component: this._tableComponent,
title: ''
}]).component();
});
}
this._loader = modelBuilder.loadingComponent()
.withItem(formModel)
.withItem(formModelBuilder.component())
.withProperties({
loading: true
}).component();

View File

@@ -4,24 +4,26 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { RegisteredModel } from '../../modelManagement/interfaces';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { RegisteredModel } from '../../../modelManagement/interfaces';
import { IDataComponent } from '../../interfaces';
/**
* View to render registered models table
*/
export class CurrentModelsTable extends ModelViewBase {
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<RegisteredModel> {
private _table: azdata.DeclarativeTableComponent;
private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined;
private _selectedModel: any;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._table = this.registerComponent(this._modelBuilder);
}
/**
@@ -29,6 +31,7 @@ export class CurrentModelsTable extends ModelViewBase {
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
this._modelBuilder = modelBuilder;
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
@@ -92,10 +95,23 @@ export class CurrentModelsTable extends ModelViewBase {
return this._table;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this.component) {
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this.component) {
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component });
}
}
/**
* Returns the component
*/
public get component(): azdata.DeclarativeTableComponent {
public get component(): azdata.DeclarativeTableComponent | undefined {
return this._table;
}
@@ -103,38 +119,45 @@ export class CurrentModelsTable extends ModelViewBase {
* Loads the data in the component
*/
public async loadData(): Promise<void> {
let models: RegisteredModel[] | undefined;
if (this._table) {
let models: RegisteredModel[] | undefined;
models = await this.listModels();
let tableData: any[][] = [];
models = await this.listModels();
let tableData: any[][] = [];
if (models) {
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
if (models) {
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
this._table.data = tableData;
}
private createTableRow(model: RegisteredModel): any[] {
if (this._modelBuilder) {
let editLanguageButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
let selectModelButton = this._modelBuilder.radioButton().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15
height: 15,
checked: false
}).component();
editLanguageButton.onDidClick(() => {
selectModelButton.onDidClick(() => {
this._selectedModel = model;
});
return [model.artifactName, model.title, model.created, editLanguageButton];
return [model.artifactName, model.title, model.created, selectModelButton];
}
return [];
}
/**
* Returns selected data
*/
public get data(): RegisteredModel | undefined {
return this._selectedModel;
}
/**
* Refreshes the view
*/

View File

@@ -4,15 +4,15 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import * as constants from '../../common/constants';
import { WizardView } from '../wizardView';
import { ModelSourcePage } from './modelSourcePage';
import { ModelDetailsPage } from './modelDetailsPage';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
import { AzureModelsComponent } from '../azureModelsComponent';
import * as constants from '../../../common/constants';
import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage';
import { ModelDetailsPage } from '../modelDetailsPage';
/**
* Wizard to register a model
@@ -47,19 +47,20 @@ export class RegisterModelWizard extends ModelViewBase {
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
if (pageInfo.newPage === undefined) {
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
if (validated && pageInfo.newPage === undefined) {
wizard.cancelButton.enabled = false;
wizard.backButton.enabled = false;
await this.registerModel();
let result = await this.registerModel();
wizard.cancelButton.enabled = true;
wizard.backButton.enabled = true;
if (this._parentView) {
this._parentView?.refresh();
await this._parentView?.refresh();
}
return true;
return result;
}
return true;
return validated;
});
wizard.open();

View File

@@ -5,10 +5,10 @@
import { CurrentModelsPage } from './currentModelsPage';
import { ModelViewBase } from './modelViewBase';
import * as constants from '../../common/constants';
import { ApiWrapper } from '../../common/apiWrapper';
import { DialogView } from '../dialogView';
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
import * as constants from '../../../common/constants';
import { ApiWrapper } from '../../../common/apiWrapper';
import { DialogView } from '../../dialogView';
/**
* Dialog to render registered model views
@@ -31,7 +31,13 @@ export class RegisteredModelsDialog extends ModelViewBase {
this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this);
let registerModelButton = this._apiWrapper.createButton(constants.registerModelTitle);
registerModelButton.onClick(async () => {
await this.sendDataRequest(RegisterModelEventName);
});
let dialog = this.dialogView.createDialog('', [this.currentLanguagesTab]);
dialog.customButtons = [registerModelButton];
this.mainViewPanel = dialog;
dialog.okButton.hidden = true;
dialog.cancelButton.label = constants.extLangDoneButtonText;

View File

@@ -4,6 +4,8 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../common/constants';
import { ApiWrapper } from '../common/apiWrapper';
import * as path from 'path';
@@ -21,8 +23,7 @@ export interface CallbackEventArgs {
}
export const CallEventNamePostfix = 'Callback';
export const LocalFileEventName = 'localFile';
export const LocalFolderEventName = 'localFolder';
export const LocalPathsEventName = 'localPaths';
/**
* Base class for views
@@ -51,7 +52,7 @@ export abstract class ViewBase extends EventEmitterCollection {
}
protected getEventNames(): string[] {
return [LocalFolderEventName, LocalFileEventName];
return [LocalPathsEventName];
}
protected getCallbackEventNames(): string[] {
@@ -118,12 +119,8 @@ export abstract class ViewBase extends EventEmitterCollection {
});
}
public async getLocalFilePath(): Promise<string> {
return await this.sendDataRequest(LocalFileEventName);
}
public async getLocalFolderPath(): Promise<string> {
return await this.sendDataRequest(LocalFolderEventName);
public async getLocalPaths(options: vscode.OpenDialogOptions): Promise<string[]> {
return await this.sendDataRequest(LocalPathsEventName, options);
}
public async getLocationTitle(): Promise<string> {
@@ -174,12 +171,12 @@ export abstract class ViewBase extends EventEmitterCollection {
}
public showErrorMessage(message: string, error?: any): void {
this.showMessage(`${message} ${constants.getErrorMessage(error)}`, azdata.window.MessageLevel.Error);
this.showMessage(`${message} ${error ? constants.getErrorMessage(error) : ''}`, azdata.window.MessageLevel.Error);
}
private showMessage(message: string, level: azdata.window.MessageLevel): void {
if (this._mainViewPanel) {
this._mainViewPanel.message = {
if (this.mainViewPanel) {
this.mainViewPanel.message = {
text: message,
level: level
};

View File

@@ -45,6 +45,19 @@ export class WizardView extends MainViewBase {
}
}
/**
* Adds wizard page
* @param page page
* @param index page index
*/
public removeWizardPage(page: IPageView, index: number): void {
if (this._wizard && this._pages[index] === page) {
this._pages = this._pages.splice(index);
this._wizard.removePage(index);
}
}
/**
*
* @param title Creates anew wizard
@@ -57,9 +70,21 @@ export class WizardView extends MainViewBase {
this._wizard.onPageChanged(async (info) => {
this.onWizardPageChanged(info);
});
return this._wizard;
}
public async validate(pageInfo: azdata.window.WizardPageChangeInfo): Promise<boolean> {
if (pageInfo.lastPage !== undefined) {
let idxLast = pageInfo.lastPage;
let lastPage = this._pages[idxLast];
if (lastPage && lastPage.validate) {
return await lastPage.validate();
}
}
return true;
}
private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
let idxLast = pageInfo.lastPage;
let lastPage = this._pages[idxLast];
@@ -73,4 +98,8 @@ export class WizardView extends MainViewBase {
page.onEnter();
}
}
public get wizard(): azdata.window.Wizard | undefined {
return this._wizard;
}
}