mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
ML - Added a link in models page to run predict on a model (#13124)
* Added a link in models page to run predict on a model * Updated the icons
This commit is contained in:
19
extensions/machine-learning/images/dark/predict_inverse.svg
Normal file
19
extensions/machine-learning/images/dark/predict_inverse.svg
Normal file
@@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- Generator: Adobe Illustrator 24.2.1, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
|
||||
<svg version="1.1" id="Layer_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 16 16" style="enable-background:new 0 0 16 16;" xml:space="preserve">
|
||||
<style type="text/css">
|
||||
.st0{fill:#C5C5C5;}
|
||||
.st1{fill:none;stroke:#C5C5C5;stroke-miterlimit:10;}
|
||||
</style>
|
||||
<g id="ff21850e-cc3c-49a5-89aa-1e729237460d">
|
||||
<circle class="st0" cx="1.5" cy="10.5" r="1.5"/>
|
||||
<circle class="st0" cx="9.5" cy="10.5" r="1.5"/>
|
||||
<circle class="st0" cx="5.4" cy="6.5" r="1.5"/>
|
||||
<path class="st0" d="M14.5,4C14.8,4,15,4.2,15,4.5S14.8,5,14.5,5S14,4.8,14,4.5S14.2,4,14.5,4 M14.5,3C13.7,3,13,3.7,13,4.5
|
||||
S13.7,6,14.5,6S16,5.3,16,4.5S15.3,3,14.5,3z"/>
|
||||
<line class="st1" x1="9.5" y1="10.5" x2="14.3" y2="5"/>
|
||||
<line class="st1" x1="1.5" y1="10.5" x2="5.4" y2="6"/>
|
||||
<line class="st1" x1="9.3" y1="10.5" x2="5.4" y2="6"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1000 B |
13
extensions/machine-learning/images/light/predict.svg
Normal file
13
extensions/machine-learning/images/light/predict.svg
Normal file
@@ -0,0 +1,13 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16">
|
||||
<title>Predict</title>
|
||||
<g id="ff21850e-cc3c-49a5-89aa-1e729237460d" data-name="Labels">
|
||||
<circle cx="1.5" cy="10.5" r="1.5"/>
|
||||
<circle cx="9.5" cy="10.5" r="1.5"/>
|
||||
<circle cx="5.4" cy="6.5" r="1.5"/>
|
||||
<path d="M14.5,4a.5.5,0,1,1-.5.5.5.5,0,0,1,.5-.5m0-1A1.5,1.5,0,1,0,16,4.5,1.5,1.5,0,0,0,14.5,3Z"/>
|
||||
<line x1="1.5" y1="10.5" x2="5" y2="6.5" fill="none" stroke="#000" stroke-miterlimit="10"/>
|
||||
<line x1="9.5" y1="10.5" x2="14.3" y2="5" fill="none" stroke="#000" stroke-miterlimit="10"/>
|
||||
<line x1="1.5" y1="10.5" x2="5.427" y2="6" fill="none" stroke="#000" stroke-miterlimit="10"/>
|
||||
<line x1="9.327" y1="10.5" x2="5.4" y2="6" fill="none" stroke="#000" stroke-miterlimit="10"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 804 B |
@@ -135,7 +135,6 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
||||
try {
|
||||
if (this._tableSelectionComponent && this._dataTable) {
|
||||
await this._tableSelectionComponent.refresh();
|
||||
await this._dataTable.refresh();
|
||||
this.refreshComponents();
|
||||
}
|
||||
} catch (err) {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase';
|
||||
import { ModelViewBase, DeleteModelEventName, EditModelEventName, PredictWizardEventName } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
import { IDataComponent, IComponentSettings } from '../../interfaces';
|
||||
@@ -20,7 +20,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
private _modelBuilder: azdata.ModelBuilder | undefined;
|
||||
private _selectedModel: ImportedModel[] = [];
|
||||
private _selectedModels: ImportedModel[] = [];
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
private _downloadedFile: ModelArtifact | undefined;
|
||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
@@ -121,6 +121,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
},
|
||||
}
|
||||
);
|
||||
columns.push(
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
@@ -157,6 +171,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
public set selectedModels(value: ImportedModel[]) {
|
||||
this._selectedModels = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the data in the component
|
||||
*/
|
||||
@@ -167,8 +185,6 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
|
||||
if (this.importTable) {
|
||||
models = await this.listModels(this.importTable);
|
||||
} else {
|
||||
this.showErrorMessage('No import table');
|
||||
}
|
||||
let tableData: any[][] = [];
|
||||
|
||||
@@ -218,13 +234,13 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
|
||||
let onSelectItem = (checked: boolean) => {
|
||||
if (!this._settings.multiSelect) {
|
||||
this._selectedModel = [];
|
||||
this._selectedModels = [];
|
||||
}
|
||||
const foundItem = this._selectedModel.find(x => x === model);
|
||||
const foundItem = this._selectedModels.find(x => x === model);
|
||||
if (checked && !foundItem) {
|
||||
this._selectedModel.push(model);
|
||||
this._selectedModels.push(model);
|
||||
} else if (foundItem) {
|
||||
this._selectedModel = this._selectedModel.filter(x => x !== model);
|
||||
this._selectedModels = this._selectedModels.filter(x => x !== model);
|
||||
}
|
||||
this.onModelSelected();
|
||||
};
|
||||
@@ -234,7 +250,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
value: model.id,
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: false
|
||||
checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id)
|
||||
}).component();
|
||||
checkbox.onChanged(() => {
|
||||
onSelectItem(checkbox.checked || false);
|
||||
@@ -246,7 +262,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
value: model.id,
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: false
|
||||
checked: this._selectedModels && this._selectedModels.find(x => x.id === model.id)
|
||||
}).component();
|
||||
radioButton.onDidClick(() => {
|
||||
onSelectItem(radioButton.checked || false);
|
||||
@@ -259,6 +275,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
|
||||
private createEditButtons(model: ImportedModel): azdata.Component[] | undefined {
|
||||
let dropButton: azdata.ButtonComponent | undefined = undefined;
|
||||
let predictButton: azdata.ButtonComponent | undefined = undefined;
|
||||
let editButton: azdata.ButtonComponent | undefined = undefined;
|
||||
if (this._modelBuilder && this._settings.editable) {
|
||||
dropButton = this._modelBuilder.button().withProperties({
|
||||
@@ -268,8 +285,8 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/delete.svg')
|
||||
},
|
||||
width: 15,
|
||||
height: 15
|
||||
width: 16,
|
||||
height: 16
|
||||
}).component();
|
||||
dropButton.onDidClick(async () => {
|
||||
try {
|
||||
@@ -284,6 +301,19 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`);
|
||||
}
|
||||
});
|
||||
predictButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
title: constants.predictModel,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/predict_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/predict.svg')
|
||||
},
|
||||
width: 16,
|
||||
height: 16
|
||||
}).component();
|
||||
predictButton.onDidClick(async () => {
|
||||
await this.sendDataRequest(PredictWizardEventName, [model]);
|
||||
});
|
||||
|
||||
editButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
@@ -292,14 +322,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/edit.svg')
|
||||
},
|
||||
width: 15,
|
||||
height: 15
|
||||
width: 16,
|
||||
height: 16
|
||||
}).component();
|
||||
editButton.onDidClick(async () => {
|
||||
await this.sendDataRequest(EditModelEventName, model);
|
||||
});
|
||||
}
|
||||
return editButton && dropButton ? [editButton, dropButton] : undefined;
|
||||
return editButton && dropButton && predictButton ? [editButton, dropButton, predictButton] : undefined;
|
||||
}
|
||||
|
||||
private async onModelSelected(): Promise<void> {
|
||||
@@ -314,7 +344,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ImportedModel[] | undefined {
|
||||
return this._selectedModel;
|
||||
return this._selectedModels;
|
||||
}
|
||||
|
||||
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
|
||||
|
||||
@@ -12,6 +12,7 @@ import { LocalModelsComponent } from './localModelsComponent';
|
||||
import { AzureModelsComponent } from './azureModelsComponent';
|
||||
import * as utils from '../../common/utils';
|
||||
import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
|
||||
import { ImportedModel } from '../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
@@ -25,7 +26,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||
public registeredModelsComponent: CurrentModelsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true,
|
||||
private _selectedModels?: ImportedModel[] | undefined) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
@@ -46,6 +48,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
||||
editable: false
|
||||
});
|
||||
this.registeredModelsComponent.registerComponent(modelBuilder);
|
||||
|
||||
// Mark a model in the list as selected
|
||||
if (this._selectedModels && this.registeredModelsComponent.modelTable) {
|
||||
this.registeredModelsComponent.modelTable.selectedModels = this._selectedModels;
|
||||
}
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import {
|
||||
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
|
||||
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName
|
||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName, PredictWizardEventName
|
||||
} from './modelViewBase';
|
||||
import { ControllerBase } from '../controllerBase';
|
||||
import { ImportModelWizard } from './manageModels/importModelWizard';
|
||||
@@ -93,20 +93,25 @@ export class ModelManagementController extends ControllerBase {
|
||||
/**
|
||||
* Opens the wizard for prediction
|
||||
*/
|
||||
public async predictModel(): Promise<ModelViewBase | undefined> {
|
||||
public async predictModel(models?: ImportedModel[] | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase | undefined> {
|
||||
|
||||
const onnxSupported = await this._predictService.serverSupportOnnxModel();
|
||||
controller = controller || this;
|
||||
apiWrapper = apiWrapper || this._apiWrapper;
|
||||
root = root || this._root;
|
||||
const onnxSupported = await controller._predictService.serverSupportOnnxModel();
|
||||
if (onnxSupported) {
|
||||
await this._deployedModelService.installDependencies();
|
||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||
view.importTable = await this._deployedModelService.getRecentImportTable();
|
||||
await controller._deployedModelService.installDependencies();
|
||||
let view = new PredictWizard(apiWrapper, root, parent, models);
|
||||
view.importTable = await controller._deployedModelService.getRecentImportTable();
|
||||
|
||||
this.registerEvents(view);
|
||||
controller.registerEvents(view);
|
||||
|
||||
view.on(LoadModelParametersEventName, async (args) => {
|
||||
const modelArtifact = await view.getModelFileName();
|
||||
await this.executeAction(view, LoadModelParametersEventName, args, this.loadModelParameters, this._deployedModelService,
|
||||
modelArtifact?.filePath);
|
||||
if (controller) {
|
||||
const modelArtifact = await view.getModelFileName();
|
||||
await controller.executeAction(view, LoadModelParametersEventName, args, controller.loadModelParameters, controller._deployedModelService,
|
||||
modelArtifact?.filePath);
|
||||
}
|
||||
});
|
||||
|
||||
// Open view
|
||||
@@ -163,6 +168,10 @@ export class ModelManagementController extends ControllerBase {
|
||||
const importTable = <DatabaseTable>args;
|
||||
await this.executeAction(view, RegisterModelEventName, args, this.importModel, importTable, view, this, this._apiWrapper, this._root);
|
||||
});
|
||||
view.on(PredictWizardEventName, async (args) => {
|
||||
const models = <ImportedModel[] | undefined>args;
|
||||
await this.executeAction(view, PredictWizardEventName, args, this.predictModel, models, view, this, this._apiWrapper, this._root);
|
||||
});
|
||||
view.on(EditModelEventName, async (args) => {
|
||||
const model = <ImportedModel>args;
|
||||
await this.executeAction(view, EditModelEventName, args, this.editModel, model, view, this, this._apiWrapper, this._root);
|
||||
|
||||
@@ -114,11 +114,14 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
}).component();
|
||||
|
||||
this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => {
|
||||
this._sourceType = this.convertSourceIdToEnum(cardId);
|
||||
if (this._selectedSourceLabel) {
|
||||
this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
|
||||
const selectedValue = this.convertSourceIdToEnum(cardId);
|
||||
if (selectedValue !== this._sourceType) {
|
||||
this._sourceType = selectedValue;
|
||||
if (this._selectedSourceLabel) {
|
||||
this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
|
||||
}
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
}
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
}));
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
|
||||
@@ -61,6 +61,7 @@ export const RegisterAzureModelEventName = 'registerAzureLocalModel';
|
||||
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
|
||||
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
|
||||
export const PredictModelEventName = 'predictModel';
|
||||
export const PredictWizardEventName = 'predictWizard';
|
||||
export const RegisterModelEventName = 'registerModel';
|
||||
export const EditModelEventName = 'editModel';
|
||||
export const UpdateModelEventName = 'updateModel';
|
||||
@@ -108,7 +109,8 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
EditModelEventName,
|
||||
UpdateModelEventName,
|
||||
DeleteModelEventName,
|
||||
SignInToAzureEventName]);
|
||||
SignInToAzureEventName,
|
||||
PredictWizardEventName]);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -94,7 +94,6 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
|
||||
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
|
||||
this.inputColumnsComponent.modelParameters = modelParameters;
|
||||
this.outputColumnsComponent.modelParameters = modelParameters;
|
||||
await this.inputColumnsComponent.refresh();
|
||||
await this.outputColumnsComponent.refresh();
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
@@ -31,7 +31,8 @@ export class PredictWizard extends ModelViewBase {
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
root: string,
|
||||
parent?: ModelViewBase) {
|
||||
parent?: ModelViewBase,
|
||||
private _selectedModels?: ImportedModel[] | undefined) {
|
||||
super(apiWrapper, root);
|
||||
this._parentView = parent;
|
||||
this.modelActionType = ModelActionType.Predict;
|
||||
@@ -44,7 +45,7 @@ export class PredictWizard extends ModelViewBase {
|
||||
this.modelSourceType = ModelSourceType.RegisteredModels;
|
||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
|
||||
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
|
||||
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false);
|
||||
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false, this._selectedModels);
|
||||
this.wizardView = new WizardView(this._apiWrapper);
|
||||
|
||||
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
|
||||
@@ -83,6 +84,9 @@ export class PredictWizard extends ModelViewBase {
|
||||
});
|
||||
|
||||
await wizard.open();
|
||||
if (this._selectedModels) {
|
||||
await wizard.setCurrentPage(wizard.pages.length - 1);
|
||||
}
|
||||
}
|
||||
|
||||
private onLoading(): void {
|
||||
|
||||
Reference in New Issue
Block a user