ML - Added a link in models page to run predict on a model (#13124)

* Added a link in models page to run predict on a model

* Updated the icons
This commit is contained in:
Leila Lali
2020-10-29 16:37:23 -07:00
committed by GitHub
parent e31d563f61
commit d450588e39
10 changed files with 121 additions and 36 deletions

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 16 16" style="enable-background:new 0 0 16 16;" xml:space="preserve">
<style type="text/css">
.st0{fill:#C5C5C5;}
.st1{fill:none;stroke:#C5C5C5;stroke-miterlimit:10;}
</style>
<g id="ff21850e-cc3c-49a5-89aa-1e729237460d">
<circle class="st0" cx="1.5" cy="10.5" r="1.5"/>
<circle class="st0" cx="9.5" cy="10.5" r="1.5"/>
<circle class="st0" cx="5.4" cy="6.5" r="1.5"/>
<path class="st0" d="M14.5,4C14.8,4,15,4.2,15,4.5S14.8,5,14.5,5S14,4.8,14,4.5S14.2,4,14.5,4 M14.5,3C13.7,3,13,3.7,13,4.5
S13.7,6,14.5,6S16,5.3,16,4.5S15.3,3,14.5,3z"/>
<line class="st1" x1="9.5" y1="10.5" x2="14.3" y2="5"/>
<line class="st1" x1="1.5" y1="10.5" x2="5.4" y2="6"/>
<line class="st1" x1="9.3" y1="10.5" x2="5.4" y2="6"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1000 B

View File

@@ -0,0 +1,13 @@
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16">
<title>Predict</title>
<g id="ff21850e-cc3c-49a5-89aa-1e729237460d" data-name="Labels">
<circle cx="1.5" cy="10.5" r="1.5"/>
<circle cx="9.5" cy="10.5" r="1.5"/>
<circle cx="5.4" cy="6.5" r="1.5"/>
<path d="M14.5,4a.5.5,0,1,1-.5.5.5.5,0,0,1,.5-.5m0-1A1.5,1.5,0,1,0,16,4.5,1.5,1.5,0,0,0,14.5,3Z"/>
<line x1="1.5" y1="10.5" x2="5" y2="6.5" fill="none" stroke="#000" stroke-miterlimit="10"/>
<line x1="9.5" y1="10.5" x2="14.3" y2="5" fill="none" stroke="#000" stroke-miterlimit="10"/>
<line x1="1.5" y1="10.5" x2="5.427" y2="6" fill="none" stroke="#000" stroke-miterlimit="10"/>
<line x1="9.327" y1="10.5" x2="5.4" y2="6" fill="none" stroke="#000" stroke-miterlimit="10"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 804 B

View File

@@ -135,7 +135,6 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
try {
if (this._tableSelectionComponent && this._dataTable) {
await this._tableSelectionComponent.refresh();
await this._dataTable.refresh();
this.refreshComponents();
}
} catch (err) {

View File

@@ -6,7 +6,7 @@
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../../common/constants';
import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase';
import { ModelViewBase, DeleteModelEventName, EditModelEventName, PredictWizardEventName } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { IDataComponent, IComponentSettings } from '../../interfaces';
@@ -20,7 +20,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined;
private _selectedModel: ImportedModel[] = [];
private _selectedModels: ImportedModel[] = [];
private _loader: azdata.LoadingComponent | undefined;
private _downloadedFile: ModelArtifact | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
@@ -121,6 +121,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
},
}
);
columns.push(
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
);
}
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
@@ -157,6 +171,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
return this._loader;
}
public set selectedModels(value: ImportedModel[]) {
this._selectedModels = value;
}
/**
* Loads the data in the component
*/
@@ -167,8 +185,6 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
if (this.importTable) {
models = await this.listModels(this.importTable);
} else {
this.showErrorMessage('No import table');
}
let tableData: any[][] = [];
@@ -218,13 +234,13 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
let onSelectItem = (checked: boolean) => {
if (!this._settings.multiSelect) {
this._selectedModel = [];
this._selectedModels = [];
}
const foundItem = this._selectedModel.find(x => x === model);
const foundItem = this._selectedModels.find(x => x === model);
if (checked && !foundItem) {
this._selectedModel.push(model);
this._selectedModels.push(model);
} else if (foundItem) {
this._selectedModel = this._selectedModel.filter(x => x !== model);
this._selectedModels = this._selectedModels.filter(x => x !== model);
}
this.onModelSelected();
};
@@ -234,7 +250,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
value: model.id,
width: 15,
height: 15,
checked: false
checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id)
}).component();
checkbox.onChanged(() => {
onSelectItem(checkbox.checked || false);
@@ -246,7 +262,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
value: model.id,
width: 15,
height: 15,
checked: false
checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id)
}).component();
radioButton.onDidClick(() => {
onSelectItem(radioButton.checked || false);
@@ -259,6 +275,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
private createEditButtons(model: ImportedModel): azdata.Component[] | undefined {
let dropButton: azdata.ButtonComponent | undefined = undefined;
let predictButton: azdata.ButtonComponent | undefined = undefined;
let editButton: azdata.ButtonComponent | undefined = undefined;
if (this._modelBuilder && this._settings.editable) {
dropButton = this._modelBuilder.button().withProperties({
@@ -268,8 +285,8 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
width: 15,
height: 15
width: 16,
height: 16
}).component();
dropButton.onDidClick(async () => {
try {
@@ -284,6 +301,19 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`);
}
});
predictButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.predictModel,
iconPath: {
dark: this.asAbsolutePath('images/dark/predict_inverse.svg'),
light: this.asAbsolutePath('images/light/predict.svg')
},
width: 16,
height: 16
}).component();
predictButton.onDidClick(async () => {
await this.sendDataRequest(PredictWizardEventName, [model]);
});
editButton = this._modelBuilder.button().withProperties({
label: '',
@@ -292,14 +322,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
width: 15,
height: 15
width: 16,
height: 16
}).component();
editButton.onDidClick(async () => {
await this.sendDataRequest(EditModelEventName, model);
});
}
return editButton && dropButton ? [editButton, dropButton] : undefined;
return editButton && dropButton && predictButton ? [editButton, dropButton, predictButton] : undefined;
}
private async onModelSelected(): Promise<void> {
@@ -314,7 +344,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
* Returns selected data
*/
public get data(): ImportedModel[] | undefined {
return this._selectedModel;
return this._selectedModels;
}
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {

View File

@@ -12,6 +12,7 @@ import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import * as utils from '../../common/utils';
import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
import { ImportedModel } from '../../modelManagement/interfaces';
/**
* View to pick model source
@@ -25,7 +26,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
public azureModelsComponent: AzureModelsComponent | undefined;
public registeredModelsComponent: CurrentModelsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true,
private _selectedModels?: ImportedModel[] | undefined) {
super(apiWrapper, parent.root, parent);
}
@@ -46,6 +48,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
editable: false
});
this.registeredModelsComponent.registerComponent(modelBuilder);
// Mark a model in the list as selected
if (this._selectedModels && this.registeredModelsComponent.modelTable) {
this.registeredModelsComponent.modelTable.selectedModels = this._selectedModels;
}
this._form = this._formBuilder.component();
return this._form;
}

View File

@@ -17,7 +17,7 @@ import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName, PredictWizardEventName
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { ImportModelWizard } from './manageModels/importModelWizard';
@@ -93,20 +93,25 @@ export class ModelManagementController extends ControllerBase {
/**
* Opens the wizard for prediction
*/
public async predictModel(): Promise<ModelViewBase | undefined> {
public async predictModel(models?: ImportedModel[] | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase | undefined> {
const onnxSupported = await this._predictService.serverSupportOnnxModel();
controller = controller || this;
apiWrapper = apiWrapper || this._apiWrapper;
root = root || this._root;
const onnxSupported = await controller._predictService.serverSupportOnnxModel();
if (onnxSupported) {
await this._deployedModelService.installDependencies();
let view = new PredictWizard(this._apiWrapper, this._root);
view.importTable = await this._deployedModelService.getRecentImportTable();
await controller._deployedModelService.installDependencies();
let view = new PredictWizard(apiWrapper, root, parent, models);
view.importTable = await controller._deployedModelService.getRecentImportTable();
this.registerEvents(view);
controller.registerEvents(view);
view.on(LoadModelParametersEventName, async (args) => {
const modelArtifact = await view.getModelFileName();
await this.executeAction(view, LoadModelParametersEventName, args, this.loadModelParameters, this._deployedModelService,
modelArtifact?.filePath);
if (controller) {
const modelArtifact = await view.getModelFileName();
await controller.executeAction(view, LoadModelParametersEventName, args, controller.loadModelParameters, controller._deployedModelService,
modelArtifact?.filePath);
}
});
// Open view
@@ -163,6 +168,10 @@ export class ModelManagementController extends ControllerBase {
const importTable = <DatabaseTable>args;
await this.executeAction(view, RegisterModelEventName, args, this.importModel, importTable, view, this, this._apiWrapper, this._root);
});
view.on(PredictWizardEventName, async (args) => {
const models = <ImportedModel[] | undefined>args;
await this.executeAction(view, PredictWizardEventName, args, this.predictModel, models, view, this, this._apiWrapper, this._root);
});
view.on(EditModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, EditModelEventName, args, this.editModel, model, view, this, this._apiWrapper, this._root);

View File

@@ -114,11 +114,14 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
}).component();
this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => {
this._sourceType = this.convertSourceIdToEnum(cardId);
if (this._selectedSourceLabel) {
this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
const selectedValue = this.convertSourceIdToEnum(cardId);
if (selectedValue !== this._sourceType) {
this._sourceType = selectedValue;
if (this._selectedSourceLabel) {
this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
}
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
}
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
}));
this._form = modelBuilder.formContainer().withFormItems([{

View File

@@ -61,6 +61,7 @@ export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
export const PredictModelEventName = 'predictModel';
export const PredictWizardEventName = 'predictWizard';
export const RegisterModelEventName = 'registerModel';
export const EditModelEventName = 'editModel';
export const UpdateModelEventName = 'updateModel';
@@ -108,7 +109,8 @@ export abstract class ModelViewBase extends ViewBase {
EditModelEventName,
UpdateModelEventName,
DeleteModelEventName,
SignInToAzureEventName]);
SignInToAzureEventName,
PredictWizardEventName]);
}
/**

View File

@@ -94,7 +94,6 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
this.inputColumnsComponent.modelParameters = modelParameters;
this.outputColumnsComponent.modelParameters = modelParameters;
await this.inputColumnsComponent.refresh();
await this.outputColumnsComponent.refresh();
}
} catch (error) {

View File

@@ -31,7 +31,8 @@ export class PredictWizard extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string,
parent?: ModelViewBase) {
parent?: ModelViewBase,
private _selectedModels?: ImportedModel[] | undefined) {
super(apiWrapper, root);
this._parentView = parent;
this.modelActionType = ModelActionType.Predict;
@@ -44,7 +45,7 @@ export class PredictWizard extends ModelViewBase {
this.modelSourceType = ModelSourceType.RegisteredModels;
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false);
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false, this._selectedModels);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
@@ -83,6 +84,9 @@ export class PredictWizard extends ModelViewBase {
});
await wizard.open();
if (this._selectedModels) {
await wizard.setCurrentPage(wizard.pages.length - 1);
}
}
private onLoading(): void {