diff --git a/extensions/machine-learning-services/config.json b/extensions/machine-learning-services/config.json index 54ea491601..1ba95ab82e 100644 --- a/extensions/machine-learning-services/config.json +++ b/extensions/machine-learning-services/config.json @@ -26,6 +26,7 @@ "modelManagement": { "registeredModelsDatabaseName": "MlFlowDB", "registeredModelsTableName": "artifacts", + "registeredModelsTableSchemaName": "dbo", "amlModelManagementUrl": "modelmanagement.azureml.net", "amlExperienceUrl": "experiments.azureml.net", "amlApiVersion": "2018-11-19", diff --git a/extensions/machine-learning-services/package.json b/extensions/machine-learning-services/package.json index 39158f3c09..e401f79ef3 100644 --- a/extensions/machine-learning-services/package.json +++ b/extensions/machine-learning-services/package.json @@ -57,6 +57,10 @@ "command": "mls.command.managePackages", "title": "%mls.command.managePackages%" }, + { + "command": "mls.command.predictModel", + "title": "%mls.command.predictModel%" + }, { "command": "mls.command.manageModels", "title": "%mls.command.manageModels%" @@ -110,7 +114,7 @@ "mls.command.managePackages", "mls.command.manageLanguages", "mls.command.manageModels", - "mls.command.registerModel" + "mls.command.predictModel" ] } }, diff --git a/extensions/machine-learning-services/package.nls.json b/extensions/machine-learning-services/package.nls.json index 2a3ea86ca5..f42cc5dd92 100644 --- a/extensions/machine-learning-services/package.nls.json +++ b/extensions/machine-learning-services/package.nls.json @@ -7,8 +7,9 @@ "title.endpoints": "Endpoints", "mls.command.managePackages": "Manage Packages in SQL Server", "mls.command.manageLanguages": "Manage External Languages", - "mls.command.manageModels": "Manage Models", - "mls.command.registerModel": "Register Model", + "mls.command.predictModel": "Make prediction", + "mls.command.manageModels": "Manage models", + "mls.command.registerModel": "Register model", "mls.command.odbcdriver": "Install ODBC Driver for SQL Server", "mls.command.mlsdocs": "Machine Learning Services Documentation", "mls.configuration.title": "Machine Learning Services configurations", diff --git a/extensions/machine-learning-services/src/common/apiWrapper.ts b/extensions/machine-learning-services/src/common/apiWrapper.ts index fd269641cf..bf2cb4867c 100644 --- a/extensions/machine-learning-services/src/common/apiWrapper.ts +++ b/extensions/machine-learning-services/src/common/apiWrapper.ts @@ -105,4 +105,28 @@ export class ApiWrapper { public showQuickPick(items: T[] | Thenable, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable { return vscode.window.showQuickPick(items, options, token); } + + public listDatabases(connectionId: string): Thenable { + return azdata.connection.listDatabases(connectionId); + } + + public openTextDocument(options?: { language?: string; content?: string; }): Thenable { + return vscode.workspace.openTextDocument(options); + } + + public connect(fileUri: string, connectionId: string): Thenable { + return azdata.queryeditor.connect(fileUri, connectionId); + } + + public runQuery(fileUri: string, options?: Map, runCurrentQuery?: boolean): void { + azdata.queryeditor.runQuery(fileUri, options, runCurrentQuery); + } + + public showTextDocument(uri: vscode.Uri, options?: vscode.TextDocumentShowOptions): Thenable { + return vscode.window.showTextDocument(uri, options); + } + + public createButton(label: string, position?: azdata.window.DialogButtonPosition): azdata.window.Button { + return azdata.window.createButton(label, position); + } } diff --git a/extensions/machine-learning-services/src/common/constants.ts b/extensions/machine-learning-services/src/common/constants.ts index e663f7321e..3d80a20f1f 100644 --- a/extensions/machine-learning-services/src/common/constants.ts +++ b/extensions/machine-learning-services/src/common/constants.ts @@ -24,6 +24,7 @@ export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups'; // Tasks, commands // export const mlManageLanguagesCommand = 'mls.command.manageLanguages'; +export const mlsPredictModelCommand = 'mls.command.predictModel'; export const mlManageModelsCommand = 'mls.command.manageModels'; export const mlRegisterModelCommand = 'mls.command.registerModel'; export const mlManagePackagesCommand = 'mls.command.managePackages'; @@ -116,6 +117,12 @@ export const modelCreated = localize('models.created', "Date Created"); export const modelVersion = localize('models.version', "Version"); export const browseModels = localize('models.browseButton', "..."); export const azureAccount = localize('models.azureAccount', "Azure account"); +export const columnDatabase = localize('predict.columnDatabase', "Database"); +export const columnTable = localize('predict.columnTable', "Table"); +export const inputColumns = localize('predict.inputColumns', "Input columns"); +export const outputColumns = localize('predict.outputColumns', "Output column"); +export const columnName = localize('predict.columnName', "Name"); +export const inputName = localize('predict.inputName', "Input Name"); export const azureSubscription = localize('models.azureSubscription', "Azure subscription"); export const azureGroup = localize('models.azureGroup', "Azure resource group"); export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace"); @@ -125,18 +132,25 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model export const localModelsTitle = localize('models.localModelsTitle', "Local models"); export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location"); export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Ender model source details"); +export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Select input columns"); export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details"); export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file"); export const currentModelsTitle = localize('models.currentModelsTitle', "Models"); export const azureRegisterModel = localize('models.azureRegisterModel', "Register"); +export const predictModel = localize('models.predictModel', "Predict"); export const registerModelTitle = localize('models.RegisterWizard', "Register model"); +export const makePredictionTitle = localize('models.makePredictionTitle', "Make prediction"); export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully"); export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register"); export const localModelSource = localize('models.localModelSource', "Upload file"); export const azureModelSource = localize('models.azureModelSource', "Import from AzureML registry"); +export const registeredModelsSource = localize('models.registeredModelsSource', "Select managed models"); export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure"); export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource"); export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register"); +export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict"); +export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model"); +export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model"); diff --git a/extensions/machine-learning-services/src/common/queryRunner.ts b/extensions/machine-learning-services/src/common/queryRunner.ts index a22fb432d9..90f92a8739 100644 --- a/extensions/machine-learning-services/src/common/queryRunner.ts +++ b/extensions/machine-learning-services/src/common/queryRunner.ts @@ -163,4 +163,21 @@ export class QueryRunner { } return result; } + + /** + * Executes the query but doesn't fail it is fails + * @param connection SQL connection + * @param query query to run + */ + public async safeRunQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise { + try { + return await this.runQuery(connection, query); + } catch (error) { + console.log(error); + return undefined; + } + } } + + + diff --git a/extensions/machine-learning-services/src/common/utils.ts b/extensions/machine-learning-services/src/common/utils.ts index 23684e7d62..483eb9440c 100644 --- a/extensions/machine-learning-services/src/common/utils.ts +++ b/extensions/machine-learning-services/src/common/utils.ts @@ -11,6 +11,7 @@ import * as fs from 'fs'; import * as constants from '../common/constants'; import { promisify } from 'util'; import { ApiWrapper } from './apiWrapper'; +import { Config } from '../configurations/config'; export async function execCommandOnTempFile(content: string, command: (filePath: string) => Promise): Promise { let tempFilePath: string = ''; @@ -25,6 +26,11 @@ export async function execCommandOnTempFile(content: string, command: (filePa } } +export async function readFileInHex(filePath: string): Promise { + let buffer = await fs.promises.readFile(filePath); + return `0X${buffer.toString('hex')}`; +} + export async function exists(path: string): Promise { return promisify(fs.exists)(path); } @@ -109,8 +115,8 @@ export function isWindows(): boolean { * ' => '' * @param value The string to escape */ -export function doubleEscapeSingleQuotes(value: string): string { - return value.replace(/'/g, '\'\''); +export function doubleEscapeSingleQuotes(value: string | undefined): string { + return value ? value.replace(/'/g, '\'\'') : ''; } /** @@ -118,8 +124,8 @@ export function doubleEscapeSingleQuotes(value: string): string { * ' => '' * @param value The string to escape */ -export function doubleEscapeSingleBrackets(value: string): string { - return value.replace(/\[/g, '[[').replace(/\]/g, ']]'); +export function doubleEscapeSingleBrackets(value: string | undefined): string { + return value ? value.replace(/\[/g, '[[').replace(/\]/g, ']]') : ''; } /** @@ -176,3 +182,48 @@ export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Pr return choices[result.label] || false; } + +export function makeLinuxPath(filePath: string): string { + const parts = filePath.split('\\'); + return parts.join('/'); +} + +/** + * + * @param currentDb Wraps the given script with database switch scripts + * @param databaseName + * @param script + */ +export function getScriptWithDBChange(currentDb: string, databaseName: string, script: string): string { + if (!currentDb) { + currentDb = 'master'; + } + let escapedDbName = doubleEscapeSingleBrackets(databaseName); + let escapedCurrentDbName = doubleEscapeSingleBrackets(currentDb); + return ` + USE [${escapedDbName}] + ${script} + USE [${escapedCurrentDbName}] + `; +} + +/** + * Returns full name of model registration table + * @param config config + */ +export function getRegisteredModelsThreePartsName(config: Config) { + const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName); + const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName); + const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName); + return `[${dbName}].${schema}.[${tableName}]`; +} + +/** + * Returns full name of model registration table + * @param config config object + */ +export function getRegisteredModelsTowPartsName(config: Config) { + const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName); + const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName); + return `[${schema}].[${tableName}]`; +} diff --git a/extensions/machine-learning-services/src/configurations/config.ts b/extensions/machine-learning-services/src/configurations/config.ts index ddbd1248f9..1da8625ef3 100644 --- a/extensions/machine-learning-services/src/configurations/config.ts +++ b/extensions/machine-learning-services/src/configurations/config.ts @@ -82,6 +82,13 @@ export class Config { return this._configValues.modelManagement.registeredModelsTableName; } + /** + * Returns registered models table schema name + */ + public get registeredModelTableSchemaName(): string { + return this._configValues.modelManagement.registeredModelsTableSchemaName; + } + /** * Returns registered models table name */ diff --git a/extensions/machine-learning-services/src/controllers/mainController.ts b/extensions/machine-learning-services/src/controllers/mainController.ts index 461505a5fc..4c4e92e346 100644 --- a/extensions/machine-learning-services/src/controllers/mainController.ts +++ b/extensions/machine-learning-services/src/controllers/mainController.ts @@ -22,6 +22,7 @@ import { ModelManagementController } from '../views/models/modelManagementContro import { RegisteredModelService } from '../modelManagement/registeredModelService'; import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService'; import { ModelImporter } from '../modelManagement/modelImporter'; +import { PredictService } from '../prediction/predictService'; /** * The main controller class that initializes the extension @@ -109,7 +110,9 @@ export default class MainController implements vscode.Disposable { // let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter); let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel); - let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, azureModelsService, registeredModelService); + let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config); + let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, + azureModelsService, registeredModelService, predictService); this._apiWrapper.registerCommand(constants.mlManageLanguagesCommand, (async () => { await languageController.manageLanguages(); @@ -120,6 +123,9 @@ export default class MainController implements vscode.Disposable { this._apiWrapper.registerCommand(constants.mlRegisterModelCommand, (async () => { await modelManagementController.registerModel(); })); + this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => { + await modelManagementController.predictModel(); + })); this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => { await packageManager.installDependencies(); })); @@ -135,6 +141,9 @@ export default class MainController implements vscode.Disposable { this._apiWrapper.registerTaskHandler(constants.mlRegisterModelCommand, async () => { await modelManagementController.registerModel(); }); + this._apiWrapper.registerTaskHandler(constants.mlsPredictModelCommand, async () => { + await modelManagementController.predictModel(); + }); this._apiWrapper.registerTaskHandler(constants.mlOdbcDriverCommand, async () => { await this.serverConfigManager.openOdbcDriverDocuments(); }); diff --git a/extensions/machine-learning-services/src/modelManagement/interfaces.ts b/extensions/machine-learning-services/src/modelManagement/interfaces.ts index 0a0af77e1c..212c3adc34 100644 --- a/extensions/machine-learning-services/src/modelManagement/interfaces.ts +++ b/extensions/machine-learning-services/src/modelManagement/interfaces.ts @@ -48,13 +48,19 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & { /** * An interface representing registered model */ -export interface RegisteredModel { - id?: number, - artifactName?: string, - title?: string, - created?: string, - version?: string - description?: string +export interface RegisteredModel extends RegisteredModelDetails { + id: number; + artifactName: string; +} + +/** + * An interface representing registered model + */ +export interface RegisteredModelDetails { + title: string; + created?: string; + version?: string; + description?: string; } /** diff --git a/extensions/machine-learning-services/src/modelManagement/modelImporter.ts b/extensions/machine-learning-services/src/modelManagement/modelImporter.ts index 007ac143e1..ad00576055 100644 --- a/extensions/machine-learning-services/src/modelManagement/modelImporter.ts +++ b/extensions/machine-learning-services/src/modelManagement/modelImporter.ts @@ -12,6 +12,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid'; import * as utils from '../common/utils'; import { PackageManager } from '../packageManagement/packageManager'; import * as constants from '../common/constants'; +import * as os from 'os'; /** * Service to import model to database @@ -39,8 +40,8 @@ export class ModelImporter { protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise { - const parts = modelFolderPath.split('\\'); - modelFolderPath = parts.join('/'); + let home = utils.makeLinuxPath(os.homedir()); + modelFolderPath = utils.makeLinuxPath(modelFolderPath); let credentials = await this._apiWrapper.getCredentials(connection.connectionId); @@ -51,9 +52,12 @@ export class ModelImporter { const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : ''; let scripts: string[] = [ 'import mlflow.onnx', + `tracking_uri = "file://${home}/mlruns"`, + 'print(tracking_uri)', 'import onnx', 'from mlflow.tracking.client import MlflowClient', `onx = onnx.load("${modelFolderPath}")`, + `mlflow.set_tracking_uri(tracking_uri)`, 'client = MlflowClient()', `exp_name = "${experimentId}"`, `db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`, diff --git a/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts b/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts index af47b7e8e7..6720bc2563 100644 --- a/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts +++ b/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts @@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper'; import * as utils from '../common/utils'; import { Config } from '../configurations/config'; import { QueryRunner } from '../common/queryRunner'; -import { RegisteredModel } from './interfaces'; +import { RegisteredModel, RegisteredModelDetails } from './interfaces'; import { ModelImporter } from './modelImporter'; import * as constants from '../common/constants'; @@ -32,7 +32,10 @@ export class RegisteredModelService { let connection = await this.getCurrentConnection(); let list: RegisteredModel[] = []; if (connection) { - let result = await this.runRegisteredModelsListQuery(connection); + let query = this.getConfigureQuery(connection.databaseName); + await this._queryRunner.safeRunQuery(connection, query); + query = this.registeredModelsQuery(); + let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { result.rows.forEach(row => { list.push(this.loadModelData(row)); @@ -57,7 +60,8 @@ export class RegisteredModelService { let connection = await this.getCurrentConnection(); let updatedModel: RegisteredModel | undefined = undefined; if (connection) { - let result = await this.runUpdateModelQuery(connection, model); + const query = this.getUpdateModelScript(connection.databaseName, model); + let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { const row = result.rows[0]; updatedModel = this.loadModelData(row); @@ -66,7 +70,7 @@ export class RegisteredModelService { return updatedModel; } - public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) { + public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) { let connection = await this.getCurrentConnection(); if (connection) { let currentModels = await this.getRegisteredModels(); @@ -93,35 +97,14 @@ export class RegisteredModelService { return await this._apiWrapper.getCurrentConnection(); } - private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise { - try { - return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName)); - } catch { - return undefined; - } + private getConfigureQuery(currentDatabaseName: string): string { + return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable()); } - private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise { - try { - return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model)); - } catch { - return undefined; - } - } - - private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string { - if (!currentDatabaseName) { - currentDatabaseName = 'master'; - } - let escapedTableName = utils.doubleEscapeSingleBrackets(tableName); - let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName); - let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName); - + private registeredModelsQuery(): string { return ` - ${this.configureTable(databaseName, tableName)} - USE [${escapedCurrentDbName}] SELECT artifact_id, artifact_name, name, description, version, created - FROM [${escapedDbName}].dbo.[${escapedTableName}] + FROM ${utils.getRegisteredModelsThreePartsName(this._config)} WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml' Order by artifact_id `; @@ -133,52 +116,74 @@ export class RegisteredModelService { * @param databaseName * @param tableName */ - private configureTable(databaseName: string, tableName: string): string { - let escapedTableName = utils.doubleEscapeSingleBrackets(tableName); - let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName); + private configureTable(): string { + let databaseName = this._config.registeredModelDatabaseName; + let tableName = this._config.registeredModelTableName; + let schemaName = this._config.registeredModelTableSchemaName; return ` - USE [${escapedDbName}] + IF NOT EXISTS ( + SELECT [name] + FROM sys.databases + WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}' + ) + CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}] + GO + USE [${utils.doubleEscapeSingleBrackets(databaseName)}] IF EXISTS - ( SELECT [name] - FROM sys.tables - WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}' + ( SELECT [t.name], [s.name] + FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id + WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}' + AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}' ) BEGIN - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name') - ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version') - ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created') + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name') + ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version') + ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created') BEGIN - ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL - ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created + ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL + ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created END - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description') - ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description') + ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL + END + Else + BEGIN + CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}( + [artifact_id] [int] IDENTITY(1,1) NOT NULL, + [artifact_name] [varchar](256) NOT NULL, + [group_path] [varchar](256) NOT NULL, + [artifact_content] [varbinary](max) NOT NULL, + [artifact_initial_size] [bigint] NULL, + [name] [varchar](256) NULL, + [version] [varchar](256) NULL, + [created] [datetime] NULL, + [description] [varchar](256) NULL, + CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED + ( + [artifact_id] ASC + )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] + ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] + ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] END `; } - private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string { + private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string { + let updateScript = ` + UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)} + SET + name = '${utils.doubleEscapeSingleQuotes(model.title || '')}', + version = '${utils.doubleEscapeSingleQuotes(model.version || '')}', + description = '${utils.doubleEscapeSingleQuotes(model.description || '')}' + WHERE artifact_id = ${model.id}`; - if (!currentDatabaseName) { - currentDatabaseName = 'master'; - } - let escapedTableName = utils.doubleEscapeSingleBrackets(tableName); - let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName); - let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName); return ` - USE [${escapedDbName}] - UPDATE ${escapedTableName} - SET - name = '${utils.doubleEscapeSingleQuotes(model.title || '')}', - version = '${utils.doubleEscapeSingleQuotes(model.version || '')}', - description = '${utils.doubleEscapeSingleQuotes(model.description || '')}' - WHERE artifact_id = ${model.id}; - - USE [${escapedCurrentDbName}] - SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}] + ${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)} + SELECT artifact_id, artifact_name, name, description, version, created + FROM ${utils.getRegisteredModelsThreePartsName(this._config)} WHERE artifact_id = ${model.id}; `; } diff --git a/extensions/machine-learning-services/src/prediction/interfaces.ts b/extensions/machine-learning-services/src/prediction/interfaces.ts new file mode 100644 index 0000000000..5fcb789edd --- /dev/null +++ b/extensions/machine-learning-services/src/prediction/interfaces.ts @@ -0,0 +1,24 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +export interface PredictColumn { + name: string; + dataType?: string; + displayName?: string; +} + +export interface DatabaseTable { + databaseName: string | undefined; + tableName: string | undefined; + schema: string | undefined +} + +export interface PredictInputParameters extends DatabaseTable { + inputColumns: PredictColumn[] | undefined +} + +export interface PredictParameters extends PredictInputParameters { + outputColumns: PredictColumn[] | undefined +} diff --git a/extensions/machine-learning-services/src/prediction/predictService.ts b/extensions/machine-learning-services/src/prediction/predictService.ts new file mode 100644 index 0000000000..31b714c82d --- /dev/null +++ b/extensions/machine-learning-services/src/prediction/predictService.ts @@ -0,0 +1,203 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; + +import { ApiWrapper } from '../common/apiWrapper'; +import { QueryRunner } from '../common/queryRunner'; +import * as utils from '../common/utils'; +import { RegisteredModel } from '../modelManagement/interfaces'; +import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces'; +import { Config } from '../configurations/config'; + +/** + * Service to make prediction + */ +export class PredictService { + + /** + * Creates new instance + */ + constructor( + private _apiWrapper: ApiWrapper, + private _queryRunner: QueryRunner, + private _config: Config) { + } + + /** + * Returns the list of databases + */ + public async getDatabaseList(): Promise { + let connection = await this.getCurrentConnection(); + if (connection) { + return await this._apiWrapper.listDatabases(connection.connectionId); + } + return []; + } + + /** + * Generates prediction script given model info and predict parameters + * @param predictParams predict parameters + * @param registeredModel model parameters + */ + public async generatePredictScript( + predictParams: PredictParameters, + registeredModel: RegisteredModel | undefined, + filePath: string | undefined + ): Promise { + let connection = await this.getCurrentConnection(); + let query = ''; + if (registeredModel && registeredModel.id) { + query = this.getPredictScriptWithModelId( + registeredModel.id, + predictParams.inputColumns || [], + predictParams.outputColumns || [], + predictParams); + } else if (filePath) { + let modelBytes = await utils.readFileInHex(filePath || ''); + query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [], + predictParams.outputColumns || [], + predictParams); + } + let document = await this._apiWrapper.openTextDocument({ + language: 'sql', + content: query + }); + await this._apiWrapper.showTextDocument(document.uri); + await this._apiWrapper.connect(document.uri.toString(), connection.connectionId); + this._apiWrapper.runQuery(document.uri.toString(), undefined, false); + return query; + } + + /** + * Returns list of tables given database name + * @param databaseName database name + */ + public async getTableList(databaseName: string): Promise { + let connection = await this.getCurrentConnection(); + let list: DatabaseTable[] = []; + if (connection) { + let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName)); + let result = await this._queryRunner.safeRunQuery(connection, query); + if (result && result.rows && result.rows.length > 0) { + result.rows.forEach(row => { + list.push({ + databaseName: databaseName, + tableName: row[0].displayValue, + schema: row[1].displayValue + }); + }); + } + } + return list; + } + + /** + *Returns list of column names of a database + * @param databaseTable table info + */ + public async getTableColumnsList(databaseTable: DatabaseTable): Promise { + let connection = await this.getCurrentConnection(); + let list: string[] = []; + if (connection && databaseTable.databaseName) { + const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable)); + let result = await this._queryRunner.safeRunQuery(connection, query); + if (result && result.rows && result.rows.length > 0) { + result.rows.forEach(row => { + list.push(row[0].displayValue); + }); + } + } + return list; + } + + private async getCurrentConnection(): Promise { + return await this._apiWrapper.getCurrentConnection(); + } + + private getTableColumnsScript(databaseTable: DatabaseTable): string { + return ` +SELECT COLUMN_NAME,* +FROM INFORMATION_SCHEMA.COLUMNS +WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}' +AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}' +AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}' + `; + } + + private getTablesScript(databaseName: string): string { + return ` +SELECT TABLE_NAME,TABLE_SCHEMA +FROM INFORMATION_SCHEMA.TABLES +WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}' + `; + } + + private getPredictScriptWithModelId( + modelId: number, + columns: PredictColumn[], + outputColumns: PredictColumn[], + databaseNameTable: DatabaseTable): string { + return ` +DECLARE @model VARBINARY(max) = ( + SELECT artifact_content + FROM ${utils.getRegisteredModelsThreePartsName(this._config)} + WHERE artifact_id = ${modelId} +); +WITH predict_input +AS ( + SELECT TOP 1000 + ${this.getColumnNames(columns, 'pi')} + FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi +) +SELECT +${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')} +FROM PREDICT(MODEL = @model, DATA = predict_input) +WITH ( + ${this.getColumnTypes(outputColumns)} +) AS p +`; + } + + private getPredictScriptWithModelBytes( + modelBytes: string, + columns: PredictColumn[], + outputColumns: PredictColumn[], + databaseNameTable: DatabaseTable): string { + return ` +WITH predict_input +AS ( + SELECT TOP 1000 + ${this.getColumnNames(columns, 'pi')} + FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi +) +SELECT +${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')} +FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input) +WITH ( + ${this.getColumnTypes(outputColumns)} +) AS p +`; + } + + private getColumnNames(columns: PredictColumn[], tableName: string) { + return columns.map(c => { + return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`; + }).join(',\n'); + } + + private getInputColumnNames(columns: PredictColumn[], tableName: string) { + return columns.map(c => { + return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`; + }).join(',\n'); + } + + private getColumnTypes(columns: PredictColumn[]) { + return columns.map(c => { + return `${c.name} ${c.dataType}`; + }).join(',\n'); + } +} + diff --git a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts index c51f342680..21c37a93ac 100644 --- a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts @@ -13,7 +13,7 @@ import { azureResource } from '../../../typings/azure-resource'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { ViewBase } from '../../../views/viewBase'; import { WorkspaceModel } from '../../../modelManagement/interfaces'; -import { RegisterModelWizard } from '../../../views/models/registerModelWizard'; +import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard'; describe('Register Model Wizard', () => { it('Should create view components successfully ', async function (): Promise { @@ -74,7 +74,8 @@ describe('Register Model Wizard', () => { let localModels: RegisteredModel[] = [ { id: 1, - artifactName: 'model' + artifactName: 'model', + title: 'model' } ]; view.on(ListModelsEventName, () => { diff --git a/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts b/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts index 610995001d..9c4d84c33f 100644 --- a/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts @@ -6,7 +6,7 @@ import * as should from 'should'; import 'mocha'; import { createContext } from './utils'; -import { RegisteredModelsDialog } from '../../../views/models/registeredModelsDialog'; +import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog'; import { ListModelsEventName } from '../../../views/models/modelViewBase'; import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ViewBase } from '../../../views/viewBase'; @@ -30,7 +30,8 @@ describe('Registered Models Dialog', () => { let models: RegisteredModel[] = [ { id: 1, - artifactName: 'model' + artifactName: 'model', + title: '' } ]; view.on(ListModelsEventName, () => { diff --git a/extensions/machine-learning-services/src/test/views/utils.ts b/extensions/machine-learning-services/src/test/views/utils.ts index 9e0e370109..3ae57623f4 100644 --- a/extensions/machine-learning-services/src/test/views/utils.ts +++ b/extensions/machine-learning-services/src/test/views/utils.ts @@ -246,6 +246,7 @@ export function createViewContext(): ViewTestContext { modelView: undefined!, valid: true }; + apiWrapper.setup(x => x.createButton(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => dialogButton); apiWrapper.setup(x => x.createTab(TypeMoq.It.isAny())).returns(() => tab); apiWrapper.setup(x => x.createWizard(TypeMoq.It.isAny())).returns(() => wizard); apiWrapper.setup(x => x.createWizardPage(TypeMoq.It.isAny())).returns(() => wizardPage); diff --git a/extensions/machine-learning-services/src/views/controllerBase.ts b/extensions/machine-learning-services/src/views/controllerBase.ts index c13a22f175..22babbbf74 100644 --- a/extensions/machine-learning-services/src/views/controllerBase.ts +++ b/extensions/machine-learning-services/src/views/controllerBase.ts @@ -3,7 +3,9 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import { ViewBase, LocalFileEventName, LocalFolderEventName } from './viewBase'; +import * as vscode from 'vscode'; + +import { ViewBase, LocalPathsEventName } from './viewBase'; import { ApiWrapper } from '../common/apiWrapper'; /** @@ -36,11 +38,8 @@ export abstract class ControllerBase { * @param view view */ public registerEvents(view: ViewBase): void { - view.on(LocalFileEventName, async () => { - await this.executeAction(view, LocalFileEventName, this.getLocalFilePath, this._apiWrapper); - }); - view.on(LocalFolderEventName, async () => { - await this.executeAction(view, LocalFolderEventName, this.getLocalFolderPath, this._apiWrapper); + view.on(LocalPathsEventName, async (args) => { + await this.executeAction(view, LocalPathsEventName, this.getLocalPaths, this._apiWrapper, args); }); } @@ -48,25 +47,8 @@ export abstract class ControllerBase { * Returns local file path picked by the user * @param apiWrapper apiWrapper */ - public async getLocalFilePath(apiWrapper: ApiWrapper): Promise { - let result = await apiWrapper.showOpenDialog({ - canSelectFiles: true, - canSelectFolders: false, - canSelectMany: false - }); - return result && result.length > 0 ? result[0].fsPath : ''; - } - - /** - * Returns local folder path picked by the user - * @param apiWrapper apiWrapper - */ - public async getLocalFolderPath(apiWrapper: ApiWrapper): Promise { - let result = await apiWrapper.showOpenDialog({ - canSelectFiles: false, - canSelectFolders: true, - canSelectMany: false - }); - return result && result.length > 0 ? result[0].fsPath : ''; + public async getLocalPaths(apiWrapper: ApiWrapper, options: vscode.OpenDialogOptions): Promise { + let result = await apiWrapper.showOpenDialog(options); + return result ? result?.map(x => x.fsPath) : []; } } diff --git a/extensions/machine-learning-services/src/views/interfaces.ts b/extensions/machine-learning-services/src/views/interfaces.ts index 294dd82d78..028a656e16 100644 --- a/extensions/machine-learning-services/src/views/interfaces.ts +++ b/extensions/machine-learning-services/src/views/interfaces.ts @@ -16,6 +16,7 @@ export interface IPageView { component: azdata.Component | undefined; onEnter?: () => Promise; onLeave?: () => Promise; + validate?: () => Promise; refresh: () => Promise; viewPanel: azdata.window.ModelViewPanel | undefined; title: string; @@ -32,3 +33,4 @@ export interface AzureModelResource extends AzureWorkspaceResource { model?: WorkspaceModel; } + diff --git a/extensions/machine-learning-services/src/views/models/localModelsComponent.ts b/extensions/machine-learning-services/src/views/models/localModelsComponent.ts index dae7d6b4c7..41eec22ada 100644 --- a/extensions/machine-learning-services/src/views/models/localModelsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/localModelsComponent.ts @@ -4,6 +4,8 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; +import * as vscode from 'vscode'; + import { ModelViewBase } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; @@ -43,9 +45,17 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen } }).component(); this._localBrowse.onDidClick(async () => { - const filePath = await this.getLocalFilePath(); + + let options: vscode.OpenDialogOptions = { + canSelectFiles: true, + canSelectFolders: false, + canSelectMany: false, + filters: { 'ONNX File': ['onnx'] } + }; + + const filePaths = await this.getLocalPaths(options); if (this._localPath) { - this._localPath.value = filePath; + this._localPath.value = filePaths && filePaths.length > 0 ? filePaths[0] : ''; } }); diff --git a/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts b/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts index aa7bb6aab2..3465ff0b95 100644 --- a/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts @@ -8,12 +8,12 @@ import { ModelViewBase } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IDataComponent } from '../interfaces'; -import { RegisteredModel } from '../../modelManagement/interfaces'; +import { RegisteredModelDetails } from '../../modelManagement/interfaces'; /** * View to pick local models file */ -export class ModelDetailsComponent extends ModelViewBase implements IDataComponent { +export class ModelDetailsComponent extends ModelViewBase implements IDataComponent { private _form: azdata.FormContainer | undefined; private _nameComponent: azdata.InputBoxComponent | undefined; @@ -81,9 +81,9 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone /** * Returns selected data */ - public get data(): RegisteredModel { + public get data(): RegisteredModelDetails { return { - title: this._nameComponent?.value, + title: this._nameComponent?.value || '', description: this._descriptionComponent?.value }; } diff --git a/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts b/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts index 8cccb19240..f2baa27bd6 100644 --- a/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts +++ b/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts @@ -9,12 +9,12 @@ import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IPageView, IDataComponent } from '../interfaces'; import { ModelDetailsComponent } from './modelDetailsComponent'; -import { RegisteredModel } from '../../modelManagement/interfaces'; +import { RegisteredModelDetails } from '../../modelManagement/interfaces'; /** * View to pick model details */ -export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent { +export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent { private _form: azdata.FormContainer | undefined; private _formBuilder: azdata.FormBuilder | undefined; @@ -43,7 +43,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC /** * Returns selected data */ - public get data(): RegisteredModel | undefined { + public get data(): RegisteredModelDetails | undefined { return this.modelDetails?.data; } @@ -66,4 +66,13 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC public get title(): string { return constants.modelDetailsPageTitle; } + + public validate(): Promise { + if (this.data && this.data.title) { + return Promise.resolve(true); + } else { + this.showErrorMessage(constants.modelNameRequiredError); + return Promise.resolve(false); + } + } } diff --git a/extensions/machine-learning-services/src/views/models/modelManagementController.ts b/extensions/machine-learning-services/src/views/models/modelManagementController.ts index db8a4dbd6f..5076d56b37 100644 --- a/extensions/machine-learning-services/src/views/models/modelManagementController.ts +++ b/extensions/machine-learning-services/src/views/models/modelManagementController.ts @@ -9,14 +9,23 @@ import { azureResource } from '../../typings/azure-resource'; import { ApiWrapper } from '../../common/apiWrapper'; import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; -import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces'; +import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces'; +import { PredictParameters, DatabaseTable } from '../../prediction/interfaces'; import { RegisteredModelService } from '../../modelManagement/registeredModelService'; -import { RegisteredModelsDialog } from './registeredModelsDialog'; -import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName } from './modelViewBase'; +import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog'; +import { + AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, + ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, + RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, + ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs +} from './modelViewBase'; import { ControllerBase } from '../controllerBase'; -import { RegisterModelWizard } from './registerModelWizard'; +import { RegisterModelWizard } from './registerModels/registerModelWizard'; import * as fs from 'fs'; import * as constants from '../../common/constants'; +import { PredictWizard } from './prediction/predictWizard'; +import { AzureModelResource } from '../interfaces'; +import { PredictService } from '../../prediction/predictService'; /** * Model management UI controller @@ -30,7 +39,8 @@ export class ModelManagementController extends ControllerBase { apiWrapper: ApiWrapper, private _root: string, private _amlService: AzureModelRegistryService, - private _registeredModelService: RegisteredModelService) { + private _registeredModelService: RegisteredModelService, + private _predictService: PredictService) { super(apiWrapper); } @@ -56,6 +66,23 @@ export class ModelManagementController extends ControllerBase { return view; } + /** + * Opens the wizard for prediction + */ + public async predictModel(): Promise { + + let view = new PredictWizard(this._apiWrapper, this._root); + + this.registerEvents(view); + + // Open view + // + view.open(); + await view.refresh(); + return view; + } + + /** * Register events in the main view * @param view main view @@ -102,6 +129,28 @@ export class ModelManagementController extends ControllerBase { await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService, registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details); }); + view.on(DownloadAzureModelEventName, async (arg) => { + let registerArgs = arg; + await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService, + registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model); + }); + view.on(ListDatabaseNamesEventName, async () => { + await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService); + }); + view.on(ListTableNamesEventName, async (arg) => { + let dbName = arg; + await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName); + }); + view.on(ListColumnNamesEventName, async (arg) => { + let tableColumnsArgs = arg; + await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService, + tableColumnsArgs); + }); + view.on(PredictModelEventName, async (arg) => { + let predictArgs = arg; + await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService, + predictArgs, predictArgs.model, predictArgs.filePath); + }); view.on(SourceModelSelectedEventName, () => { view.refresh(); }); @@ -158,7 +207,7 @@ export class ModelManagementController extends ControllerBase { return await service.getModels(account, subscription, resourceGroup, workspace) || []; } - private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise { + private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise { if (filePath) { await service.registerLocalModel(filePath, details); } else { @@ -175,7 +224,7 @@ export class ModelManagementController extends ControllerBase { resourceGroup: azureResource.AzureResource | undefined, workspace: Workspace | undefined, model: WorkspaceModel | undefined, - details: RegisteredModel | undefined): Promise { + details: RegisteredModelDetails | undefined): Promise { if (!account || !subscription || !resourceGroup || !workspace || !model || !details) { throw Error(constants.invalidAzureResourceError); } @@ -188,4 +237,47 @@ export class ModelManagementController extends ControllerBase { throw Error(constants.invalidModelToRegisterError); } } + + public async getDatabaseList(predictService: PredictService): Promise { + return await predictService.getDatabaseList(); + } + + public async getTableList(predictService: PredictService, databaseName: string): Promise { + return await predictService.getTableList(databaseName); + } + + public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise { + return await predictService.getTableColumnsList(databaseTable); + } + + private async generatePredictScript( + predictService: PredictService, + predictParams: PredictParameters, + registeredModel: RegisteredModel | undefined, + filePath: string | undefined + ): Promise { + if (!predictParams) { + throw Error(constants.invalidModelToPredictError); + } + const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath); + return result; + } + + private async downloadAzureModel( + azureService: AzureModelRegistryService, + account: azdata.Account | undefined, + subscription: azureResource.AzureResourceSubscription | undefined, + resourceGroup: azureResource.AzureResource | undefined, + workspace: Workspace | undefined, + model: WorkspaceModel | undefined): Promise { + if (!account || !subscription || !resourceGroup || !workspace || !model) { + throw Error(constants.invalidAzureResourceError); + } + const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model); + if (filePath) { + return filePath; + } else { + throw Error(constants.invalidModelToRegisterError); + } + } } diff --git a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts index dcf1d0f285..a3f521f4f2 100644 --- a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts +++ b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts @@ -11,6 +11,7 @@ import { IPageView, IDataComponent } from '../interfaces'; import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent'; import { LocalModelsComponent } from './localModelsComponent'; import { AzureModelsComponent } from './azureModelsComponent'; +import { CurrentModelsTable } from './registerModels/currentModelsTable'; /** * View to pick model source @@ -22,8 +23,9 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo public modelResources: ModelSourcesComponent | undefined; public localModelsComponent: LocalModelsComponent | undefined; public azureModelsComponent: AzureModelsComponent | undefined; + public registeredModelsComponent: CurrentModelsTable | undefined; - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) { super(apiWrapper, parent.root, parent); } @@ -34,13 +36,15 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this._formBuilder = modelBuilder.formContainer(); - this.modelResources = new ModelSourcesComponent(this._apiWrapper, this); + this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options); this.modelResources.registerComponent(modelBuilder); this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this); this.localModelsComponent.registerComponent(modelBuilder); this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this); this.azureModelsComponent.registerComponent(modelBuilder); this.modelResources.addComponents(this._formBuilder); + this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this); + this.registeredModelsComponent.registerComponent(modelBuilder); this.refresh(); this._form = this._formBuilder.component(); return this._form; @@ -66,19 +70,29 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo public async refresh(): Promise { if (this._formBuilder) { if (this.modelResources && this.modelResources.data === ModelSourceType.Local) { - if (this.localModelsComponent && this.azureModelsComponent) { + if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { this.azureModelsComponent.removeComponents(this._formBuilder); + this.registeredModelsComponent.removeComponents(this._formBuilder); this.localModelsComponent.addComponents(this._formBuilder); await this.localModelsComponent.refresh(); } } else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) { - if (this.localModelsComponent && this.azureModelsComponent) { + if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { this.localModelsComponent.removeComponents(this._formBuilder); this.azureModelsComponent.addComponents(this._formBuilder); + this.registeredModelsComponent.removeComponents(this._formBuilder); await this.azureModelsComponent.refresh(); } + } else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels) { + if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { + this.localModelsComponent.removeComponents(this._formBuilder); + this.azureModelsComponent.removeComponents(this._formBuilder); + this.registeredModelsComponent.addComponents(this._formBuilder); + await this.registeredModelsComponent.refresh(); + } + } } } @@ -89,4 +103,21 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo public get title(): string { return constants.modelSourcePageTitle; } + + public validate(): Promise { + let validated = false; + if (this.modelResources && this.modelResources.data === ModelSourceType.Local && this.localModelsComponent) { + validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0; + + } else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure && this.azureModelsComponent) { + validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.model !== undefined; + + } else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels && this.registeredModelsComponent) { + validated = this.registeredModelsComponent.data !== undefined; + } + if (!validated) { + this.showErrorMessage(constants.invalidModelToSelectError); + } + return Promise.resolve(validated); + } } diff --git a/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts b/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts index fd0f240928..ef542c58df 100644 --- a/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts +++ b/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts @@ -11,7 +11,8 @@ import { IDataComponent } from '../interfaces'; export enum ModelSourceType { Local, - Azure + Azure, + RegisteredModels } /** * View to pick model source @@ -22,9 +23,10 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone private _flexContainer: azdata.FlexContainer | undefined; private _amlModel: azdata.RadioButtonComponent | undefined; private _localModel: azdata.RadioButtonComponent | undefined; - private _isLocalModel: boolean = true; + private _registeredModels: azdata.RadioButtonComponent | undefined; + private _sourceType: ModelSourceType = ModelSourceType.Local; - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) { super(apiWrapper, parent.root, parent); } @@ -38,7 +40,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone value: 'local', name: 'modelLocation', label: constants.localModelSource, - checked: true + checked: this._options[0] === ModelSourceType.Local }).component(); @@ -47,26 +49,58 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone value: 'aml', name: 'modelLocation', label: constants.azureModelSource, + checked: this._options[0] === ModelSourceType.Azure + }).component(); + + this._registeredModels = modelBuilder.radioButton() + .withProperties({ + value: 'registered', + name: 'modelLocation', + label: constants.registeredModelsSource, + checked: this._options[0] === ModelSourceType.RegisteredModels }).component(); this._localModel.onDidClick(() => { - this._isLocalModel = true; + this._sourceType = ModelSourceType.Local; this.sendRequest(SourceModelSelectedEventName); }); this._amlModel.onDidClick(() => { - this._isLocalModel = false; + this._sourceType = ModelSourceType.Azure; this.sendRequest(SourceModelSelectedEventName); }); + this._registeredModels.onDidClick(() => { + this._sourceType = ModelSourceType.RegisteredModels; + this.sendRequest(SourceModelSelectedEventName); + }); + let components: azdata.RadioButtonComponent[] = []; + this._options.forEach(option => { + switch (option) { + case ModelSourceType.Local: + if (this._localModel) { + components.push(this._localModel); + } + break; + case ModelSourceType.Azure: + if (this._amlModel) { + components.push(this._amlModel); + } + break; + case ModelSourceType.RegisteredModels: + if (this._registeredModels) { + components.push(this._registeredModels); + } + break; + } + }); + this._sourceType = this._options[0]; this._flexContainer = modelBuilder.flexContainer() .withLayout({ flexFlow: 'column', justifyContent: 'space-between' - }).withItems([ - this._localModel, this._amlModel] - ).component(); + }).withItems(components).component(); this._form = modelBuilder.formContainer().withFormItems([{ title: '', @@ -92,7 +126,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone * Returns selected data */ public get data(): ModelSourceType { - return this._isLocalModel ? ModelSourceType.Local : ModelSourceType.Azure; + return this._sourceType; } /** diff --git a/extensions/machine-learning-services/src/views/models/modelViewBase.ts b/extensions/machine-learning-services/src/views/models/modelViewBase.ts index ac4382068b..6b1b39da06 100644 --- a/extensions/machine-learning-services/src/views/models/modelViewBase.ts +++ b/extensions/machine-learning-services/src/views/models/modelViewBase.ts @@ -8,7 +8,8 @@ import * as azdata from 'azdata'; import { azureResource } from '../../typings/azure-resource'; import { ApiWrapper } from '../../common/apiWrapper'; import { ViewBase } from '../viewBase'; -import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces'; +import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces'; +import { PredictParameters, DatabaseTable } from '../../prediction/interfaces'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { AzureWorkspaceResource, AzureModelResource } from '../interfaces'; @@ -16,13 +17,18 @@ export interface AzureResourceEventArgs extends AzureWorkspaceResource { } export interface RegisterModelEventArgs extends AzureWorkspaceResource { - details?: RegisteredModel + details?: RegisteredModelDetails } export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs { model?: WorkspaceModel; } +export interface PredictModelEventArgs extends PredictParameters { + model?: RegisteredModel; + filePath?: string; +} + export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs { filePath?: string; } @@ -32,11 +38,16 @@ export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs { export const ListModelsEventName = 'listModels'; export const ListAzureModelsEventName = 'listAzureModels'; export const ListAccountsEventName = 'listAccounts'; +export const ListDatabaseNamesEventName = 'listDatabaseNames'; +export const ListTableNamesEventName = 'listTableNames'; +export const ListColumnNamesEventName = 'listColumnNames'; export const ListSubscriptionsEventName = 'listSubscriptions'; export const ListGroupsEventName = 'listGroups'; export const ListWorkspacesEventName = 'listWorkspaces'; export const RegisterLocalModelEventName = 'registerLocalModel'; export const RegisterAzureModelEventName = 'registerAzureLocalModel'; +export const DownloadAzureModelEventName = 'downloadAzureLocalModel'; +export const PredictModelEventName = 'predictModel'; export const RegisterModelEventName = 'registerModel'; export const SourceModelSelectedEventName = 'sourceModelSelected'; @@ -59,7 +70,12 @@ export abstract class ModelViewBase extends ViewBase { RegisterLocalModelEventName, RegisterAzureModelEventName, RegisterModelEventName, - SourceModelSelectedEventName]); + SourceModelSelectedEventName, + ListDatabaseNamesEventName, + ListTableNamesEventName, + ListColumnNamesEventName, + PredictModelEventName, + DownloadAzureModelEventName]); } /** @@ -91,6 +107,27 @@ export abstract class ModelViewBase extends ViewBase { return await this.sendDataRequest(ListAccountsEventName); } + /** + * lists database names + */ + public async listDatabaseNames(): Promise { + return await this.sendDataRequest(ListDatabaseNamesEventName); + } + + /** + * lists table names + */ + public async listTableNames(dbName: string): Promise { + return await this.sendDataRequest(ListTableNamesEventName, dbName); + } + + /** + * lists column names + */ + public async listColumnNames(table: DatabaseTable): Promise { + return await this.sendDataRequest(ListColumnNamesEventName, table); + } + /** * lists azure subscriptions * @param account azure account @@ -106,7 +143,7 @@ export abstract class ModelViewBase extends ViewBase { * registers local model * @param localFilePath local file path */ - public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModel | undefined): Promise { + public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModelDetails | undefined): Promise { const args: RegisterLocalModelEventArgs = { filePath: localFilePath, details: details @@ -114,17 +151,38 @@ export abstract class ModelViewBase extends ViewBase { return await this.sendDataRequest(RegisterLocalModelEventName, args); } + /** + * download azure model + * @param args azure resource + */ + public async downloadAzureModel(resource: AzureModelResource | undefined): Promise { + return await this.sendDataRequest(DownloadAzureModelEventName, resource); + } + /** * registers azure model * @param args azure resource */ - public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModel | undefined): Promise { + public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModelDetails | undefined): Promise { const args: RegisterAzureModelEventArgs = Object.assign({}, resource, { details: details }); return await this.sendDataRequest(RegisterAzureModelEventName, args); } + /** + * registers azure model + * @param args azure resource + */ + public async generatePredictScript(model: RegisteredModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise { + const args: PredictModelEventArgs = Object.assign({}, params, { + model: model, + filePath: filePath, + loadFromRegisteredModel: !filePath + }); + return await this.sendDataRequest(PredictModelEventName, args); + } + /** * list resource groups * @param account azure account diff --git a/extensions/machine-learning-services/src/views/models/prediction/columnsFilterComponent.ts b/extensions/machine-learning-services/src/views/models/prediction/columnsFilterComponent.ts new file mode 100644 index 0000000000..65b9f53e89 --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/prediction/columnsFilterComponent.ts @@ -0,0 +1,168 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import * as constants from '../../../common/constants'; +import { IDataComponent } from '../../interfaces'; +import { ColumnsTable } from './columnsTable'; +import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces'; + +/** + * View to render filters to pick an azure resource + */ +export class ColumnsFilterComponent extends ModelViewBase implements IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _databases: azdata.DropDownComponent | undefined; + private _tables: azdata.DropDownComponent | undefined; + private _columns: ColumnsTable | undefined; + private _dbNames: string[] = []; + private _tableNames: DatabaseTable[] = []; + + /** + * Creates a new view + */ + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + super(apiWrapper, parent.root, parent); + } + + /** + * Register components + * @param modelBuilder model builder + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + this._databases = modelBuilder.dropDown().withProperties({ + width: this.componentMaxLength + }).component(); + this._tables = modelBuilder.dropDown().withProperties({ + width: this.componentMaxLength + }).component(); + this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this); + + this._databases.onValueChanged(async () => { + await this.onDatabaseSelected(); + }); + + this._tables.onValueChanged(async () => { + await this.onTableSelected(); + }); + + + this._form = modelBuilder.formContainer().withFormItems([{ + title: constants.azureAccount, + component: this._databases + }, { + title: constants.azureSubscription, + component: this._tables + }, { + title: constants.azureGroup, + component: this._columns.component + }]).component(); + return this._form; + } + + public addComponents(formBuilder: azdata.FormBuilder) { + if (this._databases && this._tables && this._columns) { + formBuilder.addFormItems([{ + title: constants.columnDatabase, + component: this._databases + }, { + title: constants.columnTable, + component: this._tables + }, { + title: constants.inputColumns, + component: this._columns.component + }]); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + if (this._databases && this._tables && this._columns) { + formBuilder.removeFormItem({ + title: constants.azureAccount, + component: this._databases + }); + formBuilder.removeFormItem({ + title: constants.azureSubscription, + component: this._tables + }); + formBuilder.removeFormItem({ + title: constants.azureGroup, + component: this._columns.component + }); + } + } + + /** + * Returns the created component + */ + public get component(): azdata.Component | undefined { + return this._form; + } + + /** + * Returns selected data + */ + public get data(): PredictInputParameters | undefined { + return Object.assign({}, this.databaseTable, { + inputColumns: this.columnNames + }); + } + + /** + * loads data in the components + */ + public async loadData(): Promise { + this._dbNames = await this.listDatabaseNames(); + if (this._databases && this._dbNames && this._dbNames.length > 0) { + this._databases.values = this._dbNames; + this._databases.value = this._dbNames[0]; + } + await this.onDatabaseSelected(); + } + + /** + * refreshes the view + */ + public async refresh(): Promise { + await this.loadData(); + } + + private async onDatabaseSelected(): Promise { + this._tableNames = await this.listTableNames(this.databaseName || ''); + if (this._tables && this._tableNames && this._tableNames.length > 0) { + this._tables.values = this._tableNames.map(t => this.getTableFullName(t)); + this._tables.value = this.getTableFullName(this._tableNames[0]); + } + await this.onTableSelected(); + } + + private getTableFullName(table: DatabaseTable): string { + return `${table.schema}.${table.tableName}`; + } + + private async onTableSelected(): Promise { + this._columns?.loadData(this.databaseTable); + } + + private get databaseName(): string | undefined { + return this._databases?.value; + } + + private get databaseTable(): DatabaseTable { + let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._tables?.value); + return { + databaseName: this.databaseName, + tableName: selectedItem?.tableName, + schema: selectedItem?.schema + }; + } + + private get columnNames(): PredictColumn[] | undefined { + return this._columns?.data; + } +} diff --git a/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts b/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts new file mode 100644 index 0000000000..f0c622442d --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts @@ -0,0 +1,84 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import * as constants from '../../../common/constants'; +import { IPageView, IDataComponent } from '../../interfaces'; +import { ColumnsFilterComponent } from './columnsFilterComponent'; +import { OutputColumnsComponent } from './outputColumnsComponent'; +import { PredictParameters } from '../../../prediction/interfaces'; + +/** + * View to pick model source + */ +export class ColumnsSelectionPage extends ModelViewBase implements IPageView, IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _formBuilder: azdata.FormBuilder | undefined; + public columnsFilterComponent: ColumnsFilterComponent | undefined; + public outputColumnsComponent: OutputColumnsComponent | undefined; + + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + super(apiWrapper, parent.root, parent); + } + + /** + * + * @param modelBuilder Register components + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + this._formBuilder = modelBuilder.formContainer(); + this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this); + this.columnsFilterComponent.registerComponent(modelBuilder); + this.columnsFilterComponent.addComponents(this._formBuilder); + this.refresh(); + + this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this); + this.outputColumnsComponent.registerComponent(modelBuilder); + this.outputColumnsComponent.addComponents(this._formBuilder); + this.refresh(); + this._form = this._formBuilder.component(); + return this._form; + } + + /** + * Returns selected data + */ + public get data(): PredictParameters | undefined { + return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ? + Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) : + undefined; + } + + /** + * Returns the component + */ + public get component(): azdata.Component | undefined { + return this._form; + } + + /** + * Refreshes the view + */ + public async refresh(): Promise { + if (this._formBuilder) { + if (this.columnsFilterComponent) { + await this.columnsFilterComponent.refresh(); + } + if (this.outputColumnsComponent) { + await this.outputColumnsComponent.refresh(); + } + } + } + + /** + * Returns page title + */ + public get title(): string { + return constants.columnSelectionPageTitle; + } +} diff --git a/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts new file mode 100644 index 0000000000..7230f7f63a --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts @@ -0,0 +1,155 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as constants from '../../../common/constants'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import { IDataComponent } from '../../interfaces'; +import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces'; + +/** + * View to render azure models in a table + */ +export class ColumnsTable extends ModelViewBase implements IDataComponent { + + private _table: azdata.DeclarativeTableComponent; + private _selectedColumns: PredictColumn[] = []; + private _columns: string[] | undefined; + + /** + * Creates a view to render azure models in a table + */ + constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) { + super(apiWrapper, parent.root, parent); + this._table = this.registerComponent(this._modelBuilder); + } + + /** + * Register components + * @param modelBuilder model builder + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent { + this._table = modelBuilder.declarativeTable() + .withProperties( + { + columns: [ + { // Name + displayName: constants.columnDatabase, + ariaLabel: constants.columnName, + valueType: azdata.DeclarativeDataType.string, + isReadOnly: true, + width: 120, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Action + displayName: constants.inputName, + ariaLabel: constants.inputName, + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Action + displayName: '', + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + } + ], + data: [], + ariaLabel: constants.mlsConfigTitle + }) + .component(); + return this._table; + } + + public get component(): azdata.DeclarativeTableComponent { + return this._table; + } + + /** + * Load data in the component + * @param workspaceResource Azure workspace + */ + public async loadData(table: DatabaseTable): Promise { + this._selectedColumns = []; + if (this._table) { + this._columns = await this.listColumnNames(table); + let tableData: any[][] = []; + + if (this._columns) { + tableData = tableData.concat(this._columns.map(model => this.createTableRow(model))); + } + + this._table.data = tableData; + } + } + + private createTableRow(column: string): any[] { + if (this._modelBuilder) { + let selectRowButton = this._modelBuilder.checkBox().withProperties({ + + width: 15, + height: 15, + checked: true + }).component(); + let nameInputBox = this._modelBuilder.inputBox().withProperties({ + value: '', + width: 150 + }).component(); + this._selectedColumns.push({ name: column }); + selectRowButton.onChanged(() => { + if (selectRowButton.checked) { + if (!this._selectedColumns.find(x => x.name === column)) { + this._selectedColumns.push({ name: column }); + } + } else { + if (this._selectedColumns.find(x => x.name === column)) { + this._selectedColumns = this._selectedColumns.filter(x => x.name !== column); + } + } + }); + nameInputBox.onTextChanged(() => { + let selectedRow = this._selectedColumns.find(x => x.name === column); + if (selectedRow) { + selectedRow.displayName = nameInputBox.value; + } + }); + return [column, nameInputBox, selectRowButton]; + } + + return []; + } + + /** + * Returns selected data + */ + public get data(): PredictColumn[] | undefined { + return this._selectedColumns; + } + + /** + * Refreshes the view + */ + public async refresh(): Promise { + } +} diff --git a/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts b/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts new file mode 100644 index 0000000000..35782e9492 --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts @@ -0,0 +1,124 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import * as constants from '../../../common/constants'; +import { IDataComponent } from '../../interfaces'; +import { PredictColumn } from '../../../prediction/interfaces'; + +/** + * View to render filters to pick an azure resource + */ +const componentWidth = 60; +export class OutputColumnsComponent extends ModelViewBase implements IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _flex: azdata.FlexContainer | undefined; + private _columnName: azdata.InputBoxComponent | undefined; + private _columnTypes: azdata.DropDownComponent | undefined; + private _dataTypes: string[] = [ + 'int', + 'nvarchar(MAX)', + 'varchar(MAX)', + 'float', + 'double', + 'bit' + ]; + + /** + * Creates a new view + */ + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + super(apiWrapper, parent.root, parent); + } + + /** + * Register components + * @param modelBuilder model builder + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + this._columnName = modelBuilder.inputBox().withProperties({ + width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength + }).component(); + this._columnTypes = modelBuilder.dropDown().withProperties({ + width: componentWidth + }).component(); + + let flex = modelBuilder.flexContainer() + .withLayout({ + width: this._columnName.width + }).withItems([ + this._columnName] + ).component(); + this._flex = modelBuilder.flexContainer() + .withLayout({ + flexFlow: 'row', + justifyContent: 'space-between', + width: this.componentMaxLength + }).withItems([ + flex, this._columnTypes] + ).component(); + + this._form = modelBuilder.formContainer().withFormItems([{ + title: constants.azureAccount, + component: this._flex + }]).component(); + return this._form; + } + + public addComponents(formBuilder: azdata.FormBuilder) { + if (this._flex) { + formBuilder.addFormItems([{ + title: constants.outputColumns, + component: this._flex + }]); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + if (this._flex) { + formBuilder.removeFormItem({ + title: constants.outputColumns, + component: this._flex + }); + } + } + + /** + * Returns the created component + */ + public get component(): azdata.Component | undefined { + return this._form; + } + + /** + * loads data in the components + */ + public async loadData(): Promise { + if (this._columnTypes) { + this._columnTypes.values = this._dataTypes; + this._columnTypes.value = this._dataTypes[0]; + } + } + + /** + * refreshes the view + */ + public async refresh(): Promise { + await this.loadData(); + } + + /** + * Returns selected data + */ + public get data(): PredictColumn[] | undefined { + return this._columnName && this._columnTypes ? [{ + name: this._columnName.value || '', + dataType: this._columnTypes.value || '' + }] : undefined; + } +} diff --git a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts new file mode 100644 index 0000000000..c7c3784f72 --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts @@ -0,0 +1,111 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent'; +import { LocalModelsComponent } from '../localModelsComponent'; +import { AzureModelsComponent } from '../azureModelsComponent'; +import * as constants from '../../../common/constants'; +import { WizardView } from '../../wizardView'; +import { ModelSourcePage } from '../modelSourcePage'; +import { ColumnsSelectionPage } from './columnsSelectionPage'; +import { RegisteredModel } from '../../../modelManagement/interfaces'; + +/** + * Wizard to register a model + */ +export class PredictWizard extends ModelViewBase { + + public modelSourcePage: ModelSourcePage | undefined; + //public modelDetailsPage: ModelDetailsPage | undefined; + public columnsSelectionPage: ColumnsSelectionPage | undefined; + public wizardView: WizardView | undefined; + private _parentView: ModelViewBase | undefined; + + constructor( + apiWrapper: ApiWrapper, + root: string, + parent?: ModelViewBase) { + super(apiWrapper, root); + this._parentView = parent; + } + + /** + * Opens a dialog to manage packages used by notebooks. + */ + public open(): void { + this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]); + this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this); + this.wizardView = new WizardView(this._apiWrapper); + + let wizard = this.wizardView.createWizard(constants.makePredictionTitle, + [this.modelSourcePage, + this.columnsSelectionPage]); + + this.mainViewPanel = wizard; + wizard.doneButton.label = constants.predictModel; + wizard.generateScriptButton.hidden = true; + wizard.displayPageTitles = true; + wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => { + let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; + if (validated && pageInfo.newPage === undefined) { + wizard.cancelButton.enabled = false; + wizard.backButton.enabled = false; + await this.predict(); + wizard.cancelButton.enabled = true; + wizard.backButton.enabled = true; + if (this._parentView) { + this._parentView?.refresh(); + } + return true; + + } + return validated; + }); + + wizard.open(); + } + + public get modelResources(): ModelSourcesComponent | undefined { + return this.modelSourcePage?.modelResources; + } + + public get localModelsComponent(): LocalModelsComponent | undefined { + return this.modelSourcePage?.localModelsComponent; + } + + public get azureModelsComponent(): AzureModelsComponent | undefined { + return this.modelSourcePage?.azureModelsComponent; + } + + private async predict(): Promise { + try { + let modelFilePath: string = ''; + let registeredModel: RegisteredModel | undefined = undefined; + if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) { + modelFilePath = this.localModelsComponent.data; + } else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) { + modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data); + } else { + registeredModel = this.modelSourcePage?.registeredModelsComponent?.data; + } + + await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data); + return true; + } catch (error) { + this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`); + return false; + } + } + + /** + * Refresh the pages + */ + public async refresh(): Promise { + await this.wizardView?.refresh(); + } +} diff --git a/extensions/machine-learning-services/src/views/models/currentModelsPage.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts similarity index 74% rename from extensions/machine-learning-services/src/views/models/currentModelsPage.ts rename to extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts index 3647534861..b113858ef7 100644 --- a/extensions/machine-learning-services/src/views/models/currentModelsPage.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts @@ -5,11 +5,11 @@ import * as azdata from 'azdata'; -import * as constants from '../../common/constants'; -import { ModelViewBase, RegisterModelEventName } from './modelViewBase'; +import * as constants from '../../../common/constants'; +import { ModelViewBase } from '../modelViewBase'; import { CurrentModelsTable } from './currentModelsTable'; -import { ApiWrapper } from '../../common/apiWrapper'; -import { IPageView } from '../interfaces'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import { IPageView } from '../../interfaces'; /** * View to render current registered models @@ -33,28 +33,21 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView { * @param modelBuilder register the components */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._dataTable = new CurrentModelsTable(this._apiWrapper, modelBuilder, this); + this._dataTable = new CurrentModelsTable(this._apiWrapper, this); + this._dataTable.registerComponent(modelBuilder); this._tableComponent = this._dataTable.component; - let registerButton = modelBuilder.button().withProperties({ - label: constants.registerModelTitle, - width: this.buttonMaxLength - }).component(); - registerButton.onDidClick(async () => { - await this.sendDataRequest(RegisterModelEventName); - }); + let formModelBuilder = modelBuilder.formContainer(); - let formModel = modelBuilder.formContainer() - .withFormItems([{ - title: '', - component: registerButton - }, { + if (this._tableComponent) { + formModelBuilder.addFormItem({ component: this._tableComponent, title: '' - }]).component(); + }); + } this._loader = modelBuilder.loadingComponent() - .withItem(formModel) + .withItem(formModelBuilder.component()) .withProperties({ loading: true }).component(); diff --git a/extensions/machine-learning-services/src/views/models/currentModelsTable.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts similarity index 60% rename from extensions/machine-learning-services/src/views/models/currentModelsTable.ts rename to extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts index 54f745222b..34b91d488c 100644 --- a/extensions/machine-learning-services/src/views/models/currentModelsTable.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts @@ -4,24 +4,26 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import * as constants from '../../common/constants'; -import { ModelViewBase } from './modelViewBase'; -import { ApiWrapper } from '../../common/apiWrapper'; -import { RegisteredModel } from '../../modelManagement/interfaces'; +import * as constants from '../../../common/constants'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import { RegisteredModel } from '../../../modelManagement/interfaces'; +import { IDataComponent } from '../../interfaces'; /** * View to render registered models table */ -export class CurrentModelsTable extends ModelViewBase { +export class CurrentModelsTable extends ModelViewBase implements IDataComponent { - private _table: azdata.DeclarativeTableComponent; + private _table: azdata.DeclarativeTableComponent | undefined; + private _modelBuilder: azdata.ModelBuilder | undefined; + private _selectedModel: any; /** * Creates new view */ - constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { super(apiWrapper, parent.root, parent); - this._table = this.registerComponent(this._modelBuilder); } /** @@ -29,6 +31,7 @@ export class CurrentModelsTable extends ModelViewBase { * @param modelBuilder register the components */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent { + this._modelBuilder = modelBuilder; this._table = modelBuilder.declarativeTable() .withProperties( { @@ -92,10 +95,23 @@ export class CurrentModelsTable extends ModelViewBase { return this._table; } + public addComponents(formBuilder: azdata.FormBuilder) { + if (this.component) { + formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component }); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + if (this.component) { + formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component }); + } + } + + /** * Returns the component */ - public get component(): azdata.DeclarativeTableComponent { + public get component(): azdata.DeclarativeTableComponent | undefined { return this._table; } @@ -103,38 +119,45 @@ export class CurrentModelsTable extends ModelViewBase { * Loads the data in the component */ public async loadData(): Promise { - let models: RegisteredModel[] | undefined; + if (this._table) { + let models: RegisteredModel[] | undefined; - models = await this.listModels(); - let tableData: any[][] = []; + models = await this.listModels(); + let tableData: any[][] = []; - if (models) { - tableData = tableData.concat(models.map(model => this.createTableRow(model))); + if (models) { + tableData = tableData.concat(models.map(model => this.createTableRow(model))); + } + + this._table.data = tableData; } - - this._table.data = tableData; } private createTableRow(model: RegisteredModel): any[] { if (this._modelBuilder) { - let editLanguageButton = this._modelBuilder.button().withProperties({ - label: '', - title: constants.deleteTitle, - iconPath: { - dark: this.asAbsolutePath('images/dark/edit_inverse.svg'), - light: this.asAbsolutePath('images/light/edit.svg') - }, + let selectModelButton = this._modelBuilder.radioButton().withProperties({ + name: 'amlModel', + value: model.id, width: 15, - height: 15 + height: 15, + checked: false }).component(); - editLanguageButton.onDidClick(() => { + selectModelButton.onDidClick(() => { + this._selectedModel = model; }); - return [model.artifactName, model.title, model.created, editLanguageButton]; + return [model.artifactName, model.title, model.created, selectModelButton]; } return []; } + /** + * Returns selected data + */ + public get data(): RegisteredModel | undefined { + return this._selectedModel; + } + /** * Refreshes the view */ diff --git a/extensions/machine-learning-services/src/views/models/registerModelWizard.ts b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts similarity index 78% rename from extensions/machine-learning-services/src/views/models/registerModelWizard.ts rename to extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts index 84e9d77095..73013595f1 100644 --- a/extensions/machine-learning-services/src/views/models/registerModelWizard.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts @@ -4,15 +4,15 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase } from './modelViewBase'; -import { ApiWrapper } from '../../common/apiWrapper'; -import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent'; -import { LocalModelsComponent } from './localModelsComponent'; -import { AzureModelsComponent } from './azureModelsComponent'; -import * as constants from '../../common/constants'; -import { WizardView } from '../wizardView'; -import { ModelSourcePage } from './modelSourcePage'; -import { ModelDetailsPage } from './modelDetailsPage'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent'; +import { LocalModelsComponent } from '../localModelsComponent'; +import { AzureModelsComponent } from '../azureModelsComponent'; +import * as constants from '../../../common/constants'; +import { WizardView } from '../../wizardView'; +import { ModelSourcePage } from '../modelSourcePage'; +import { ModelDetailsPage } from '../modelDetailsPage'; /** * Wizard to register a model @@ -47,19 +47,20 @@ export class RegisterModelWizard extends ModelViewBase { wizard.generateScriptButton.hidden = true; wizard.displayPageTitles = true; wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => { - if (pageInfo.newPage === undefined) { + let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; + if (validated && pageInfo.newPage === undefined) { wizard.cancelButton.enabled = false; wizard.backButton.enabled = false; - await this.registerModel(); + let result = await this.registerModel(); wizard.cancelButton.enabled = true; wizard.backButton.enabled = true; if (this._parentView) { - this._parentView?.refresh(); + await this._parentView?.refresh(); } - return true; + return result; } - return true; + return validated; }); wizard.open(); diff --git a/extensions/machine-learning-services/src/views/models/registeredModelsDialog.ts b/extensions/machine-learning-services/src/views/models/registerModels/registeredModelsDialog.ts similarity index 75% rename from extensions/machine-learning-services/src/views/models/registeredModelsDialog.ts rename to extensions/machine-learning-services/src/views/models/registerModels/registeredModelsDialog.ts index aeb6bdbde4..c999beb611 100644 --- a/extensions/machine-learning-services/src/views/models/registeredModelsDialog.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/registeredModelsDialog.ts @@ -5,10 +5,10 @@ import { CurrentModelsPage } from './currentModelsPage'; -import { ModelViewBase } from './modelViewBase'; -import * as constants from '../../common/constants'; -import { ApiWrapper } from '../../common/apiWrapper'; -import { DialogView } from '../dialogView'; +import { ModelViewBase, RegisterModelEventName } from '../modelViewBase'; +import * as constants from '../../../common/constants'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import { DialogView } from '../../dialogView'; /** * Dialog to render registered model views @@ -31,7 +31,13 @@ export class RegisteredModelsDialog extends ModelViewBase { this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this); + let registerModelButton = this._apiWrapper.createButton(constants.registerModelTitle); + registerModelButton.onClick(async () => { + await this.sendDataRequest(RegisterModelEventName); + }); + let dialog = this.dialogView.createDialog('', [this.currentLanguagesTab]); + dialog.customButtons = [registerModelButton]; this.mainViewPanel = dialog; dialog.okButton.hidden = true; dialog.cancelButton.label = constants.extLangDoneButtonText; diff --git a/extensions/machine-learning-services/src/views/viewBase.ts b/extensions/machine-learning-services/src/views/viewBase.ts index 912a9781e3..3702bab3bf 100644 --- a/extensions/machine-learning-services/src/views/viewBase.ts +++ b/extensions/machine-learning-services/src/views/viewBase.ts @@ -4,6 +4,8 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; +import * as vscode from 'vscode'; + import * as constants from '../common/constants'; import { ApiWrapper } from '../common/apiWrapper'; import * as path from 'path'; @@ -21,8 +23,7 @@ export interface CallbackEventArgs { } export const CallEventNamePostfix = 'Callback'; -export const LocalFileEventName = 'localFile'; -export const LocalFolderEventName = 'localFolder'; +export const LocalPathsEventName = 'localPaths'; /** * Base class for views @@ -51,7 +52,7 @@ export abstract class ViewBase extends EventEmitterCollection { } protected getEventNames(): string[] { - return [LocalFolderEventName, LocalFileEventName]; + return [LocalPathsEventName]; } protected getCallbackEventNames(): string[] { @@ -118,12 +119,8 @@ export abstract class ViewBase extends EventEmitterCollection { }); } - public async getLocalFilePath(): Promise { - return await this.sendDataRequest(LocalFileEventName); - } - - public async getLocalFolderPath(): Promise { - return await this.sendDataRequest(LocalFolderEventName); + public async getLocalPaths(options: vscode.OpenDialogOptions): Promise { + return await this.sendDataRequest(LocalPathsEventName, options); } public async getLocationTitle(): Promise { @@ -174,12 +171,12 @@ export abstract class ViewBase extends EventEmitterCollection { } public showErrorMessage(message: string, error?: any): void { - this.showMessage(`${message} ${constants.getErrorMessage(error)}`, azdata.window.MessageLevel.Error); + this.showMessage(`${message} ${error ? constants.getErrorMessage(error) : ''}`, azdata.window.MessageLevel.Error); } private showMessage(message: string, level: azdata.window.MessageLevel): void { - if (this._mainViewPanel) { - this._mainViewPanel.message = { + if (this.mainViewPanel) { + this.mainViewPanel.message = { text: message, level: level }; diff --git a/extensions/machine-learning-services/src/views/wizardView.ts b/extensions/machine-learning-services/src/views/wizardView.ts index 6c1f613ce3..33976ec0fb 100644 --- a/extensions/machine-learning-services/src/views/wizardView.ts +++ b/extensions/machine-learning-services/src/views/wizardView.ts @@ -45,6 +45,19 @@ export class WizardView extends MainViewBase { } } + /** + * Adds wizard page + * @param page page + * @param index page index + */ + public removeWizardPage(page: IPageView, index: number): void { + if (this._wizard && this._pages[index] === page) { + this._pages = this._pages.splice(index); + this._wizard.removePage(index); + } + } + + /** * * @param title Creates anew wizard @@ -57,9 +70,21 @@ export class WizardView extends MainViewBase { this._wizard.onPageChanged(async (info) => { this.onWizardPageChanged(info); }); + return this._wizard; } + public async validate(pageInfo: azdata.window.WizardPageChangeInfo): Promise { + if (pageInfo.lastPage !== undefined) { + let idxLast = pageInfo.lastPage; + let lastPage = this._pages[idxLast]; + if (lastPage && lastPage.validate) { + return await lastPage.validate(); + } + } + return true; + } + private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) { let idxLast = pageInfo.lastPage; let lastPage = this._pages[idxLast]; @@ -73,4 +98,8 @@ export class WizardView extends MainViewBase { page.onEnter(); } } + + public get wizard(): azdata.window.Wizard | undefined { + return this._wizard; + } }