diff --git a/extensions/machine-learning/images/dark/predict_inverse.svg b/extensions/machine-learning/images/dark/predict_inverse.svg
new file mode 100644
index 0000000000..59132b78d6
--- /dev/null
+++ b/extensions/machine-learning/images/dark/predict_inverse.svg
@@ -0,0 +1,19 @@
+
+
+
diff --git a/extensions/machine-learning/images/light/predict.svg b/extensions/machine-learning/images/light/predict.svg
new file mode 100644
index 0000000000..c0cafb59ab
--- /dev/null
+++ b/extensions/machine-learning/images/light/predict.svg
@@ -0,0 +1,13 @@
+
diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts
index 325c263930..393c75a026 100644
--- a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts
+++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts
@@ -135,7 +135,6 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
try {
if (this._tableSelectionComponent && this._dataTable) {
await this._tableSelectionComponent.refresh();
- await this._dataTable.refresh();
this.refreshComponents();
}
} catch (err) {
diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts
index 4192cca97f..ea831d3b4d 100644
--- a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts
+++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts
@@ -6,7 +6,7 @@
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../../common/constants';
-import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase';
+import { ModelViewBase, DeleteModelEventName, EditModelEventName, PredictWizardEventName } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { IDataComponent, IComponentSettings } from '../../interfaces';
@@ -20,7 +20,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined;
- private _selectedModel: ImportedModel[] = [];
+ private _selectedModels: ImportedModel[] = [];
private _loader: azdata.LoadingComponent | undefined;
private _downloadedFile: ModelArtifact | undefined;
private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter();
@@ -121,6 +121,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
},
}
);
+ columns.push(
+ { // Action
+ displayName: '',
+ valueType: azdata.DeclarativeDataType.component,
+ isReadOnly: true,
+ width: 50,
+ headerCssStyles: {
+ ...constants.cssStyles.tableHeader
+ },
+ rowCssStyles: {
+ ...constants.cssStyles.tableRow
+ },
+ }
+ );
}
this._table = modelBuilder.declarativeTable()
.withProperties(
@@ -157,6 +171,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
return this._loader;
}
+ public set selectedModels(value: ImportedModel[]) {
+ this._selectedModels = value;
+ }
+
/**
* Loads the data in the component
*/
@@ -167,8 +185,6 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
if (this.importTable) {
models = await this.listModels(this.importTable);
- } else {
- this.showErrorMessage('No import table');
}
let tableData: any[][] = [];
@@ -218,13 +234,13 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
let onSelectItem = (checked: boolean) => {
if (!this._settings.multiSelect) {
- this._selectedModel = [];
+ this._selectedModels = [];
}
- const foundItem = this._selectedModel.find(x => x === model);
+ const foundItem = this._selectedModels.find(x => x === model);
if (checked && !foundItem) {
- this._selectedModel.push(model);
+ this._selectedModels.push(model);
} else if (foundItem) {
- this._selectedModel = this._selectedModel.filter(x => x !== model);
+ this._selectedModels = this._selectedModels.filter(x => x !== model);
}
this.onModelSelected();
};
@@ -234,7 +250,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
value: model.id,
width: 15,
height: 15,
- checked: false
+ checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id)
}).component();
checkbox.onChanged(() => {
onSelectItem(checkbox.checked || false);
@@ -246,7 +262,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
value: model.id,
width: 15,
height: 15,
- checked: false
+ checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id)
}).component();
radioButton.onDidClick(() => {
onSelectItem(radioButton.checked || false);
@@ -259,6 +275,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
private createEditButtons(model: ImportedModel): azdata.Component[] | undefined {
let dropButton: azdata.ButtonComponent | undefined = undefined;
+ let predictButton: azdata.ButtonComponent | undefined = undefined;
let editButton: azdata.ButtonComponent | undefined = undefined;
if (this._modelBuilder && this._settings.editable) {
dropButton = this._modelBuilder.button().withProperties({
@@ -268,8 +285,8 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
- width: 15,
- height: 15
+ width: 16,
+ height: 16
}).component();
dropButton.onDidClick(async () => {
try {
@@ -284,6 +301,19 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`);
}
});
+ predictButton = this._modelBuilder.button().withProperties({
+ label: '',
+ title: constants.predictModel,
+ iconPath: {
+ dark: this.asAbsolutePath('images/dark/predict_inverse.svg'),
+ light: this.asAbsolutePath('images/light/predict.svg')
+ },
+ width: 16,
+ height: 16
+ }).component();
+ predictButton.onDidClick(async () => {
+ await this.sendDataRequest(PredictWizardEventName, [model]);
+ });
editButton = this._modelBuilder.button().withProperties({
label: '',
@@ -292,14 +322,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
- width: 15,
- height: 15
+ width: 16,
+ height: 16
}).component();
editButton.onDidClick(async () => {
await this.sendDataRequest(EditModelEventName, model);
});
}
- return editButton && dropButton ? [editButton, dropButton] : undefined;
+ return editButton && dropButton && predictButton ? [editButton, dropButton, predictButton] : undefined;
}
private async onModelSelected(): Promise {
@@ -314,7 +344,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
* Returns selected data
*/
public get data(): ImportedModel[] | undefined {
- return this._selectedModel;
+ return this._selectedModels;
}
public async getDownloadedModel(): Promise {
diff --git a/extensions/machine-learning/src/views/models/modelBrowsePage.ts b/extensions/machine-learning/src/views/models/modelBrowsePage.ts
index af1f1c2a32..a8cfb45a3a 100644
--- a/extensions/machine-learning/src/views/models/modelBrowsePage.ts
+++ b/extensions/machine-learning/src/views/models/modelBrowsePage.ts
@@ -12,6 +12,7 @@ import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import * as utils from '../../common/utils';
import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
+import { ImportedModel } from '../../modelManagement/interfaces';
/**
* View to pick model source
@@ -25,7 +26,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
public azureModelsComponent: AzureModelsComponent | undefined;
public registeredModelsComponent: CurrentModelsComponent | undefined;
- constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
+ constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true,
+ private _selectedModels?: ImportedModel[] | undefined) {
super(apiWrapper, parent.root, parent);
}
@@ -46,6 +48,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
editable: false
});
this.registeredModelsComponent.registerComponent(modelBuilder);
+
+ // Mark a model in the list as selected
+ if (this._selectedModels && this.registeredModelsComponent.modelTable) {
+ this.registeredModelsComponent.modelTable.selectedModels = this._selectedModels;
+ }
this._form = this._formBuilder.component();
return this._form;
}
diff --git a/extensions/machine-learning/src/views/models/modelManagementController.ts b/extensions/machine-learning/src/views/models/modelManagementController.ts
index f0aa6c5934..e993b1aa75 100644
--- a/extensions/machine-learning/src/views/models/modelManagementController.ts
+++ b/extensions/machine-learning/src/views/models/modelManagementController.ts
@@ -17,7 +17,7 @@ import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
- ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName
+ ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName, PredictWizardEventName
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { ImportModelWizard } from './manageModels/importModelWizard';
@@ -93,20 +93,25 @@ export class ModelManagementController extends ControllerBase {
/**
* Opens the wizard for prediction
*/
- public async predictModel(): Promise {
+ public async predictModel(models?: ImportedModel[] | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise {
- const onnxSupported = await this._predictService.serverSupportOnnxModel();
+ controller = controller || this;
+ apiWrapper = apiWrapper || this._apiWrapper;
+ root = root || this._root;
+ const onnxSupported = await controller._predictService.serverSupportOnnxModel();
if (onnxSupported) {
- await this._deployedModelService.installDependencies();
- let view = new PredictWizard(this._apiWrapper, this._root);
- view.importTable = await this._deployedModelService.getRecentImportTable();
+ await controller._deployedModelService.installDependencies();
+ let view = new PredictWizard(apiWrapper, root, parent, models);
+ view.importTable = await controller._deployedModelService.getRecentImportTable();
- this.registerEvents(view);
+ controller.registerEvents(view);
view.on(LoadModelParametersEventName, async (args) => {
- const modelArtifact = await view.getModelFileName();
- await this.executeAction(view, LoadModelParametersEventName, args, this.loadModelParameters, this._deployedModelService,
- modelArtifact?.filePath);
+ if (controller) {
+ const modelArtifact = await view.getModelFileName();
+ await controller.executeAction(view, LoadModelParametersEventName, args, controller.loadModelParameters, controller._deployedModelService,
+ modelArtifact?.filePath);
+ }
});
// Open view
@@ -163,6 +168,10 @@ export class ModelManagementController extends ControllerBase {
const importTable = args;
await this.executeAction(view, RegisterModelEventName, args, this.importModel, importTable, view, this, this._apiWrapper, this._root);
});
+ view.on(PredictWizardEventName, async (args) => {
+ const models = args;
+ await this.executeAction(view, PredictWizardEventName, args, this.predictModel, models, view, this, this._apiWrapper, this._root);
+ });
view.on(EditModelEventName, async (args) => {
const model = args;
await this.executeAction(view, EditModelEventName, args, this.editModel, model, view, this, this._apiWrapper, this._root);
diff --git a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts
index c8598ef27f..8b860e6cc3 100644
--- a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts
+++ b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts
@@ -114,11 +114,14 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
}).component();
this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => {
- this._sourceType = this.convertSourceIdToEnum(cardId);
- if (this._selectedSourceLabel) {
- this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
+ const selectedValue = this.convertSourceIdToEnum(cardId);
+ if (selectedValue !== this._sourceType) {
+ this._sourceType = selectedValue;
+ if (this._selectedSourceLabel) {
+ this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
+ }
+ this.sendRequest(SourceModelSelectedEventName, this._sourceType);
}
- this.sendRequest(SourceModelSelectedEventName, this._sourceType);
}));
this._form = modelBuilder.formContainer().withFormItems([{
diff --git a/extensions/machine-learning/src/views/models/modelViewBase.ts b/extensions/machine-learning/src/views/models/modelViewBase.ts
index b10c3b5252..b4fb7aa18f 100644
--- a/extensions/machine-learning/src/views/models/modelViewBase.ts
+++ b/extensions/machine-learning/src/views/models/modelViewBase.ts
@@ -61,6 +61,7 @@ export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
export const PredictModelEventName = 'predictModel';
+export const PredictWizardEventName = 'predictWizard';
export const RegisterModelEventName = 'registerModel';
export const EditModelEventName = 'editModel';
export const UpdateModelEventName = 'updateModel';
@@ -108,7 +109,8 @@ export abstract class ModelViewBase extends ViewBase {
EditModelEventName,
UpdateModelEventName,
DeleteModelEventName,
- SignInToAzureEventName]);
+ SignInToAzureEventName,
+ PredictWizardEventName]);
}
/**
diff --git a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts
index 79f8029807..abb2c899e7 100644
--- a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts
+++ b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts
@@ -94,7 +94,6 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
this.inputColumnsComponent.modelParameters = modelParameters;
this.outputColumnsComponent.modelParameters = modelParameters;
- await this.inputColumnsComponent.refresh();
await this.outputColumnsComponent.refresh();
}
} catch (error) {
diff --git a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts
index 5ca9c4d778..7d554c55e4 100644
--- a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts
+++ b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts
@@ -31,7 +31,8 @@ export class PredictWizard extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string,
- parent?: ModelViewBase) {
+ parent?: ModelViewBase,
+ private _selectedModels?: ImportedModel[] | undefined) {
super(apiWrapper, root);
this._parentView = parent;
this.modelActionType = ModelActionType.Predict;
@@ -44,7 +45,7 @@ export class PredictWizard extends ModelViewBase {
this.modelSourceType = ModelSourceType.RegisteredModels;
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
- this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false);
+ this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false, this._selectedModels);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
@@ -83,6 +84,9 @@ export class PredictWizard extends ModelViewBase {
});
await wizard.open();
+ if (this._selectedModels) {
+ await wizard.setCurrentPage(wizard.pages.length - 1);
+ }
}
private onLoading(): void {