mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-07 17:23:56 -05:00
ML extension - Improving predict parameter mapping experience (#10264)
This commit is contained in:
@@ -40,19 +40,19 @@ export class ModelManagementController extends ControllerBase {
|
||||
apiWrapper: ApiWrapper,
|
||||
private _root: string,
|
||||
private _amlService: AzureModelRegistryService,
|
||||
private _registeredModelService: DeployedModelService,
|
||||
private _deployedModelService: DeployedModelService,
|
||||
private _predictService: PredictService) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the dialog for model registration
|
||||
* Opens the dialog for model import
|
||||
* @param parent parent if the view is opened from another view
|
||||
* @param controller controller
|
||||
* @param apiWrapper apiWrapper
|
||||
* @param root root folder path
|
||||
*/
|
||||
public async registerModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
|
||||
public async importModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
|
||||
controller = controller || this;
|
||||
apiWrapper = apiWrapper || this._apiWrapper;
|
||||
root = root || this._root;
|
||||
@@ -60,7 +60,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
if (importTable) {
|
||||
view.importTable = importTable;
|
||||
} else {
|
||||
view.importTable = await controller._registeredModelService.getRecentImportTable();
|
||||
view.importTable = await controller._deployedModelService.getRecentImportTable();
|
||||
}
|
||||
|
||||
controller.registerEvents(view);
|
||||
@@ -93,23 +93,31 @@ export class ModelManagementController extends ControllerBase {
|
||||
/**
|
||||
* Opens the wizard for prediction
|
||||
*/
|
||||
public async predictModel(): Promise<ModelViewBase> {
|
||||
public async predictModel(): Promise<ModelViewBase | undefined> {
|
||||
|
||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||
view.importTable = await this._registeredModelService.getRecentImportTable();
|
||||
const onnxSupported = await this._predictService.serverSupportOnnxModel();
|
||||
if (onnxSupported) {
|
||||
await this._deployedModelService.installDependencies();
|
||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||
view.importTable = await this._deployedModelService.getRecentImportTable();
|
||||
|
||||
this.registerEvents(view);
|
||||
view.on(LoadModelParametersEventName, async () => {
|
||||
const modelArtifact = await view.getModelFileName();
|
||||
await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
|
||||
modelArtifact?.filePath);
|
||||
});
|
||||
this.registerEvents(view);
|
||||
|
||||
// Open view
|
||||
//
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
view.on(LoadModelParametersEventName, async (args) => {
|
||||
const modelArtifact = await view.getModelFileName();
|
||||
await this.executeAction(view, LoadModelParametersEventName, args, this.loadModelParameters, this._deployedModelService,
|
||||
modelArtifact?.filePath);
|
||||
});
|
||||
|
||||
// Open view
|
||||
//
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
} else {
|
||||
this._apiWrapper.showErrorMessage(constants.onnxNotSupportedError);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -122,99 +130,99 @@ export class ModelManagementController extends ControllerBase {
|
||||
// Register events
|
||||
//
|
||||
super.registerEvents(view);
|
||||
view.on(ListAccountsEventName, async () => {
|
||||
await this.executeAction(view, ListAccountsEventName, this.getAzureAccounts, this._amlService);
|
||||
view.on(ListAccountsEventName, async (args) => {
|
||||
await this.executeAction(view, ListAccountsEventName, args, this.getAzureAccounts, this._amlService);
|
||||
});
|
||||
view.on(ListSubscriptionsEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListSubscriptionsEventName, this.getAzureSubscriptions, this._amlService, azureArgs.account);
|
||||
view.on(ListSubscriptionsEventName, async (args) => {
|
||||
let azureArgs = <AzureResourceEventArgs>args;
|
||||
await this.executeAction(view, ListSubscriptionsEventName, args, this.getAzureSubscriptions, this._amlService, azureArgs.account);
|
||||
});
|
||||
view.on(ListWorkspacesEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListWorkspacesEventName, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group);
|
||||
view.on(ListWorkspacesEventName, async (args) => {
|
||||
let azureArgs = <AzureResourceEventArgs>args;
|
||||
await this.executeAction(view, ListWorkspacesEventName, args, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group);
|
||||
});
|
||||
view.on(ListGroupsEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListGroupsEventName, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription);
|
||||
view.on(ListGroupsEventName, async (args) => {
|
||||
let azureArgs = <AzureResourceEventArgs>args;
|
||||
await this.executeAction(view, ListGroupsEventName, args, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription);
|
||||
});
|
||||
view.on(ListAzureModelsEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService
|
||||
view.on(ListAzureModelsEventName, async (args) => {
|
||||
let azureArgs = <AzureResourceEventArgs>args;
|
||||
await this.executeAction(view, ListAzureModelsEventName, args, this.getAzureModels, this._amlService
|
||||
, azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace);
|
||||
});
|
||||
view.on(ListModelsEventName, async (args) => {
|
||||
const table = <DatabaseTable>args;
|
||||
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService, table);
|
||||
await this.executeAction(view, ListModelsEventName, args, this.getRegisteredModels, this._deployedModelService, table);
|
||||
});
|
||||
view.on(RegisterLocalModelEventName, async (arg) => {
|
||||
let models = <ModelViewData[]>arg;
|
||||
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models);
|
||||
view.on(RegisterLocalModelEventName, async (args) => {
|
||||
let models = <ModelViewData[]>args;
|
||||
await this.executeAction(view, RegisterLocalModelEventName, args, this.registerLocalModel, this._deployedModelService, models);
|
||||
view.refresh();
|
||||
});
|
||||
view.on(RegisterModelEventName, async (args) => {
|
||||
const importTable = <DatabaseTable>args;
|
||||
await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root);
|
||||
await this.executeAction(view, RegisterModelEventName, args, this.importModel, importTable, view, this, this._apiWrapper, this._root);
|
||||
});
|
||||
view.on(EditModelEventName, async (args) => {
|
||||
const model = <ImportedModel>args;
|
||||
await this.executeAction(view, EditModelEventName, this.editModel, model, view, this, this._apiWrapper, this._root);
|
||||
await this.executeAction(view, EditModelEventName, args, this.editModel, model, view, this, this._apiWrapper, this._root);
|
||||
});
|
||||
view.on(UpdateModelEventName, async (args) => {
|
||||
const model = <ImportedModel>args;
|
||||
await this.executeAction(view, UpdateModelEventName, this.updateModel, this._registeredModelService, model);
|
||||
await this.executeAction(view, UpdateModelEventName, args, this.updateModel, this._deployedModelService, model);
|
||||
});
|
||||
view.on(DeleteModelEventName, async (args) => {
|
||||
const model = <ImportedModel>args;
|
||||
await this.executeAction(view, DeleteModelEventName, this.deleteModel, this._registeredModelService, model);
|
||||
await this.executeAction(view, DeleteModelEventName, args, this.deleteModel, this._deployedModelService, model);
|
||||
});
|
||||
view.on(RegisterAzureModelEventName, async (arg) => {
|
||||
let models = <ModelViewData[]>arg;
|
||||
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
||||
view.on(RegisterAzureModelEventName, async (args) => {
|
||||
let models = <ModelViewData[]>args;
|
||||
await this.executeAction(view, RegisterAzureModelEventName, args, this.registerAzureModel, this._amlService, this._deployedModelService,
|
||||
models);
|
||||
});
|
||||
view.on(DownloadAzureModelEventName, async (arg) => {
|
||||
let registerArgs = <AzureModelResource>arg;
|
||||
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
|
||||
view.on(DownloadAzureModelEventName, async (args) => {
|
||||
let registerArgs = <AzureModelResource>args;
|
||||
await this.executeAction(view, DownloadAzureModelEventName, args, this.downloadAzureModel, this._amlService,
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, async () => {
|
||||
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
|
||||
view.on(ListDatabaseNamesEventName, async (args) => {
|
||||
await this.executeAction(view, ListDatabaseNamesEventName, args, this.getDatabaseList, this._predictService);
|
||||
});
|
||||
view.on(ListTableNamesEventName, async (arg) => {
|
||||
let dbName = <string>arg;
|
||||
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
|
||||
view.on(ListTableNamesEventName, async (args) => {
|
||||
let dbName = <string>args;
|
||||
await this.executeAction(view, ListTableNamesEventName, args, this.getTableList, this._predictService, dbName);
|
||||
});
|
||||
view.on(ListColumnNamesEventName, async (arg) => {
|
||||
let tableColumnsArgs = <DatabaseTable>arg;
|
||||
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
|
||||
view.on(ListColumnNamesEventName, async (args) => {
|
||||
let tableColumnsArgs = <DatabaseTable>args;
|
||||
await this.executeAction(view, ListColumnNamesEventName, args, this.getTableColumnsList, this._predictService,
|
||||
tableColumnsArgs);
|
||||
});
|
||||
view.on(PredictModelEventName, async (arg) => {
|
||||
let predictArgs = <PredictModelEventArgs>arg;
|
||||
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
|
||||
view.on(PredictModelEventName, async (args) => {
|
||||
let predictArgs = <PredictModelEventArgs>args;
|
||||
await this.executeAction(view, PredictModelEventName, args, this.generatePredictScript, this._predictService,
|
||||
predictArgs, predictArgs.model, predictArgs.filePath);
|
||||
});
|
||||
view.on(DownloadRegisteredModelEventName, async (arg) => {
|
||||
let model = <ImportedModel>arg;
|
||||
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
|
||||
view.on(DownloadRegisteredModelEventName, async (args) => {
|
||||
let model = <ImportedModel>args;
|
||||
await this.executeAction(view, DownloadRegisteredModelEventName, args, this.downloadRegisteredModel, this._deployedModelService,
|
||||
model);
|
||||
});
|
||||
view.on(StoreImportTableEventName, async (arg) => {
|
||||
let importTable = <DatabaseTable>arg;
|
||||
await this.executeAction(view, StoreImportTableEventName, this.storeImportTable, this._registeredModelService,
|
||||
view.on(StoreImportTableEventName, async (args) => {
|
||||
let importTable = <DatabaseTable>args;
|
||||
await this.executeAction(view, StoreImportTableEventName, args, this.storeImportTable, this._deployedModelService,
|
||||
importTable);
|
||||
});
|
||||
view.on(VerifyImportTableEventName, async (arg) => {
|
||||
let importTable = <DatabaseTable>arg;
|
||||
await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService,
|
||||
view.on(VerifyImportTableEventName, async (args) => {
|
||||
let importTable = <DatabaseTable>args;
|
||||
await this.executeAction(view, VerifyImportTableEventName, args, this.verifyImportTable, this._deployedModelService,
|
||||
importTable);
|
||||
});
|
||||
view.on(SourceModelSelectedEventName, async (arg) => {
|
||||
view.modelSourceType = <ModelSourceType>arg;
|
||||
view.on(SourceModelSelectedEventName, async (args) => {
|
||||
view.modelSourceType = <ModelSourceType>args;
|
||||
await view.refresh();
|
||||
});
|
||||
view.on(SignInToAzureEventName, async () => {
|
||||
await this.executeAction(view, SignInToAzureEventName, this.signInToAzure, this._amlService);
|
||||
view.on(SignInToAzureEventName, async (args) => {
|
||||
await this.executeAction(view, SignInToAzureEventName, args, this.signInToAzure, this._amlService);
|
||||
await view.refresh();
|
||||
});
|
||||
}
|
||||
@@ -228,7 +236,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
if (importTable) {
|
||||
view.importTable = importTable;
|
||||
} else {
|
||||
view.importTable = await this._registeredModelService.getRecentImportTable();
|
||||
view.importTable = await this._deployedModelService.getRecentImportTable();
|
||||
}
|
||||
|
||||
// Register events
|
||||
|
||||
Reference in New Issue
Block a user