From d450588e392184804876f7ea827ca16310ec6015 Mon Sep 17 00:00:00 2001 From: Leila Lali Date: Thu, 29 Oct 2020 16:37:23 -0700 Subject: [PATCH] ML - Added a link in models page to run predict on a model (#13124) * Added a link in models page to run predict on a model * Updated the icons --- .../images/dark/predict_inverse.svg | 19 ++++++ .../machine-learning/images/light/predict.svg | 13 ++++ .../manageModels/currentModelsComponent.ts | 1 - .../models/manageModels/currentModelsTable.ts | 62 ++++++++++++++----- .../src/views/models/modelBrowsePage.ts | 9 ++- .../views/models/modelManagementController.ts | 29 ++++++--- .../src/views/models/modelSourcesComponent.ts | 11 ++-- .../src/views/models/modelViewBase.ts | 4 +- .../models/prediction/columnsSelectionPage.ts | 1 - .../views/models/prediction/predictWizard.ts | 8 ++- 10 files changed, 121 insertions(+), 36 deletions(-) create mode 100644 extensions/machine-learning/images/dark/predict_inverse.svg create mode 100644 extensions/machine-learning/images/light/predict.svg diff --git a/extensions/machine-learning/images/dark/predict_inverse.svg b/extensions/machine-learning/images/dark/predict_inverse.svg new file mode 100644 index 0000000000..59132b78d6 --- /dev/null +++ b/extensions/machine-learning/images/dark/predict_inverse.svg @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + diff --git a/extensions/machine-learning/images/light/predict.svg b/extensions/machine-learning/images/light/predict.svg new file mode 100644 index 0000000000..c0cafb59ab --- /dev/null +++ b/extensions/machine-learning/images/light/predict.svg @@ -0,0 +1,13 @@ + + Predict + + + + + + + + + + + diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts index 325c263930..393c75a026 100644 --- a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts +++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts @@ -135,7 +135,6 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { try { if (this._tableSelectionComponent && this._dataTable) { await this._tableSelectionComponent.refresh(); - await this._dataTable.refresh(); this.refreshComponents(); } } catch (err) { diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts index 4192cca97f..ea831d3b4d 100644 --- a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts +++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts @@ -6,7 +6,7 @@ import * as azdata from 'azdata'; import * as vscode from 'vscode'; import * as constants from '../../../common/constants'; -import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase'; +import { ModelViewBase, DeleteModelEventName, EditModelEventName, PredictWizardEventName } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; import { ImportedModel } from '../../../modelManagement/interfaces'; import { IDataComponent, IComponentSettings } from '../../interfaces'; @@ -20,7 +20,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< private _table: azdata.DeclarativeTableComponent | undefined; private _modelBuilder: azdata.ModelBuilder | undefined; - private _selectedModel: ImportedModel[] = []; + private _selectedModels: ImportedModel[] = []; private _loader: azdata.LoadingComponent | undefined; private _downloadedFile: ModelArtifact | undefined; private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter(); @@ -121,6 +121,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< }, } ); + columns.push( + { // Action + displayName: '', + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + } + ); } this._table = modelBuilder.declarativeTable() .withProperties( @@ -157,6 +171,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< return this._loader; } + public set selectedModels(value: ImportedModel[]) { + this._selectedModels = value; + } + /** * Loads the data in the component */ @@ -167,8 +185,6 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< if (this.importTable) { models = await this.listModels(this.importTable); - } else { - this.showErrorMessage('No import table'); } let tableData: any[][] = []; @@ -218,13 +234,13 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< let onSelectItem = (checked: boolean) => { if (!this._settings.multiSelect) { - this._selectedModel = []; + this._selectedModels = []; } - const foundItem = this._selectedModel.find(x => x === model); + const foundItem = this._selectedModels.find(x => x === model); if (checked && !foundItem) { - this._selectedModel.push(model); + this._selectedModels.push(model); } else if (foundItem) { - this._selectedModel = this._selectedModel.filter(x => x !== model); + this._selectedModels = this._selectedModels.filter(x => x !== model); } this.onModelSelected(); }; @@ -234,7 +250,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< value: model.id, width: 15, height: 15, - checked: false + checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id) }).component(); checkbox.onChanged(() => { onSelectItem(checkbox.checked || false); @@ -246,7 +262,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< value: model.id, width: 15, height: 15, - checked: false + checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id) }).component(); radioButton.onDidClick(() => { onSelectItem(radioButton.checked || false); @@ -259,6 +275,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< private createEditButtons(model: ImportedModel): azdata.Component[] | undefined { let dropButton: azdata.ButtonComponent | undefined = undefined; + let predictButton: azdata.ButtonComponent | undefined = undefined; let editButton: azdata.ButtonComponent | undefined = undefined; if (this._modelBuilder && this._settings.editable) { dropButton = this._modelBuilder.button().withProperties({ @@ -268,8 +285,8 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< dark: this.asAbsolutePath('images/dark/delete_inverse.svg'), light: this.asAbsolutePath('images/light/delete.svg') }, - width: 15, - height: 15 + width: 16, + height: 16 }).component(); dropButton.onDidClick(async () => { try { @@ -284,6 +301,19 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`); } }); + predictButton = this._modelBuilder.button().withProperties({ + label: '', + title: constants.predictModel, + iconPath: { + dark: this.asAbsolutePath('images/dark/predict_inverse.svg'), + light: this.asAbsolutePath('images/light/predict.svg') + }, + width: 16, + height: 16 + }).component(); + predictButton.onDidClick(async () => { + await this.sendDataRequest(PredictWizardEventName, [model]); + }); editButton = this._modelBuilder.button().withProperties({ label: '', @@ -292,14 +322,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< dark: this.asAbsolutePath('images/dark/edit_inverse.svg'), light: this.asAbsolutePath('images/light/edit.svg') }, - width: 15, - height: 15 + width: 16, + height: 16 }).component(); editButton.onDidClick(async () => { await this.sendDataRequest(EditModelEventName, model); }); } - return editButton && dropButton ? [editButton, dropButton] : undefined; + return editButton && dropButton && predictButton ? [editButton, dropButton, predictButton] : undefined; } private async onModelSelected(): Promise { @@ -314,7 +344,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< * Returns selected data */ public get data(): ImportedModel[] | undefined { - return this._selectedModel; + return this._selectedModels; } public async getDownloadedModel(): Promise { diff --git a/extensions/machine-learning/src/views/models/modelBrowsePage.ts b/extensions/machine-learning/src/views/models/modelBrowsePage.ts index af1f1c2a32..a8cfb45a3a 100644 --- a/extensions/machine-learning/src/views/models/modelBrowsePage.ts +++ b/extensions/machine-learning/src/views/models/modelBrowsePage.ts @@ -12,6 +12,7 @@ import { LocalModelsComponent } from './localModelsComponent'; import { AzureModelsComponent } from './azureModelsComponent'; import * as utils from '../../common/utils'; import { CurrentModelsComponent } from './manageModels/currentModelsComponent'; +import { ImportedModel } from '../../modelManagement/interfaces'; /** * View to pick model source @@ -25,7 +26,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo public azureModelsComponent: AzureModelsComponent | undefined; public registeredModelsComponent: CurrentModelsComponent | undefined; - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true, + private _selectedModels?: ImportedModel[] | undefined) { super(apiWrapper, parent.root, parent); } @@ -46,6 +48,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo editable: false }); this.registeredModelsComponent.registerComponent(modelBuilder); + + // Mark a model in the list as selected + if (this._selectedModels && this.registeredModelsComponent.modelTable) { + this.registeredModelsComponent.modelTable.selectedModels = this._selectedModels; + } this._form = this._formBuilder.component(); return this._form; } diff --git a/extensions/machine-learning/src/views/models/modelManagementController.ts b/extensions/machine-learning/src/views/models/modelManagementController.ts index f0aa6c5934..e993b1aa75 100644 --- a/extensions/machine-learning/src/views/models/modelManagementController.ts +++ b/extensions/machine-learning/src/views/models/modelManagementController.ts @@ -17,7 +17,7 @@ import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, - ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName + ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName, PredictWizardEventName } from './modelViewBase'; import { ControllerBase } from '../controllerBase'; import { ImportModelWizard } from './manageModels/importModelWizard'; @@ -93,20 +93,25 @@ export class ModelManagementController extends ControllerBase { /** * Opens the wizard for prediction */ - public async predictModel(): Promise { + public async predictModel(models?: ImportedModel[] | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise { - const onnxSupported = await this._predictService.serverSupportOnnxModel(); + controller = controller || this; + apiWrapper = apiWrapper || this._apiWrapper; + root = root || this._root; + const onnxSupported = await controller._predictService.serverSupportOnnxModel(); if (onnxSupported) { - await this._deployedModelService.installDependencies(); - let view = new PredictWizard(this._apiWrapper, this._root); - view.importTable = await this._deployedModelService.getRecentImportTable(); + await controller._deployedModelService.installDependencies(); + let view = new PredictWizard(apiWrapper, root, parent, models); + view.importTable = await controller._deployedModelService.getRecentImportTable(); - this.registerEvents(view); + controller.registerEvents(view); view.on(LoadModelParametersEventName, async (args) => { - const modelArtifact = await view.getModelFileName(); - await this.executeAction(view, LoadModelParametersEventName, args, this.loadModelParameters, this._deployedModelService, - modelArtifact?.filePath); + if (controller) { + const modelArtifact = await view.getModelFileName(); + await controller.executeAction(view, LoadModelParametersEventName, args, controller.loadModelParameters, controller._deployedModelService, + modelArtifact?.filePath); + } }); // Open view @@ -163,6 +168,10 @@ export class ModelManagementController extends ControllerBase { const importTable = args; await this.executeAction(view, RegisterModelEventName, args, this.importModel, importTable, view, this, this._apiWrapper, this._root); }); + view.on(PredictWizardEventName, async (args) => { + const models = args; + await this.executeAction(view, PredictWizardEventName, args, this.predictModel, models, view, this, this._apiWrapper, this._root); + }); view.on(EditModelEventName, async (args) => { const model = args; await this.executeAction(view, EditModelEventName, args, this.editModel, model, view, this, this._apiWrapper, this._root); diff --git a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts index c8598ef27f..8b860e6cc3 100644 --- a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts +++ b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts @@ -114,11 +114,14 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone }).component(); this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => { - this._sourceType = this.convertSourceIdToEnum(cardId); - if (this._selectedSourceLabel) { - this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType); + const selectedValue = this.convertSourceIdToEnum(cardId); + if (selectedValue !== this._sourceType) { + this._sourceType = selectedValue; + if (this._selectedSourceLabel) { + this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType); + } + this.sendRequest(SourceModelSelectedEventName, this._sourceType); } - this.sendRequest(SourceModelSelectedEventName, this._sourceType); })); this._form = modelBuilder.formContainer().withFormItems([{ diff --git a/extensions/machine-learning/src/views/models/modelViewBase.ts b/extensions/machine-learning/src/views/models/modelViewBase.ts index b10c3b5252..b4fb7aa18f 100644 --- a/extensions/machine-learning/src/views/models/modelViewBase.ts +++ b/extensions/machine-learning/src/views/models/modelViewBase.ts @@ -61,6 +61,7 @@ export const RegisterAzureModelEventName = 'registerAzureLocalModel'; export const DownloadAzureModelEventName = 'downloadAzureLocalModel'; export const DownloadRegisteredModelEventName = 'downloadRegisteredModel'; export const PredictModelEventName = 'predictModel'; +export const PredictWizardEventName = 'predictWizard'; export const RegisterModelEventName = 'registerModel'; export const EditModelEventName = 'editModel'; export const UpdateModelEventName = 'updateModel'; @@ -108,7 +109,8 @@ export abstract class ModelViewBase extends ViewBase { EditModelEventName, UpdateModelEventName, DeleteModelEventName, - SignInToAzureEventName]); + SignInToAzureEventName, + PredictWizardEventName]); } /** diff --git a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts index 79f8029807..abb2c899e7 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts @@ -94,7 +94,6 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) { this.inputColumnsComponent.modelParameters = modelParameters; this.outputColumnsComponent.modelParameters = modelParameters; - await this.inputColumnsComponent.refresh(); await this.outputColumnsComponent.refresh(); } } catch (error) { diff --git a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts index 5ca9c4d778..7d554c55e4 100644 --- a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts +++ b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts @@ -31,7 +31,8 @@ export class PredictWizard extends ModelViewBase { constructor( apiWrapper: ApiWrapper, root: string, - parent?: ModelViewBase) { + parent?: ModelViewBase, + private _selectedModels?: ImportedModel[] | undefined) { super(apiWrapper, root); this._parentView = parent; this.modelActionType = ModelActionType.Predict; @@ -44,7 +45,7 @@ export class PredictWizard extends ModelViewBase { this.modelSourceType = ModelSourceType.RegisteredModels; this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]); this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this); - this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false); + this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false, this._selectedModels); this.wizardView = new WizardView(this._apiWrapper); let wizard = this.wizardView.createWizard(constants.makePredictionTitle, @@ -83,6 +84,9 @@ export class PredictWizard extends ModelViewBase { }); await wizard.open(); + if (this._selectedModels) { + await wizard.setCurrentPage(wizard.pages.length - 1); + } } private onLoading(): void {