Machine Learning Services Extension - Predict wizard (#9450)

*MLS extension - Added predict wizard
This commit is contained in:
Leila Lali
2020-03-09 15:40:05 -07:00
committed by GitHub
parent b017634431
commit 3be3563b0d
37 changed files with 1501 additions and 219 deletions

View File

@@ -9,14 +9,23 @@ import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
import { RegisteredModelsDialog } from './registeredModelsDialog';
import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName } from './modelViewBase';
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { RegisterModelWizard } from './registerModelWizard';
import { RegisterModelWizard } from './registerModels/registerModelWizard';
import * as fs from 'fs';
import * as constants from '../../common/constants';
import { PredictWizard } from './prediction/predictWizard';
import { AzureModelResource } from '../interfaces';
import { PredictService } from '../../prediction/predictService';
/**
* Model management UI controller
@@ -30,7 +39,8 @@ export class ModelManagementController extends ControllerBase {
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
private _registeredModelService: RegisteredModelService) {
private _registeredModelService: RegisteredModelService,
private _predictService: PredictService) {
super(apiWrapper);
}
@@ -56,6 +66,23 @@ export class ModelManagementController extends ControllerBase {
return view;
}
/**
* Opens the wizard for prediction
*/
public async predictModel(): Promise<ModelViewBase> {
let view = new PredictWizard(this._apiWrapper, this._root);
this.registerEvents(view);
// Open view
//
view.open();
await view.refresh();
return view;
}
/**
* Register events in the main view
* @param view main view
@@ -102,6 +129,28 @@ export class ModelManagementController extends ControllerBase {
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
});
view.on(DownloadAzureModelEventName, async (arg) => {
let registerArgs = <AzureModelResource>arg;
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
});
view.on(ListDatabaseNamesEventName, async () => {
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
});
view.on(ListTableNamesEventName, async (arg) => {
let dbName = <string>arg;
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
});
view.on(ListColumnNamesEventName, async (arg) => {
let tableColumnsArgs = <DatabaseTable>arg;
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
tableColumnsArgs);
});
view.on(PredictModelEventName, async (arg) => {
let predictArgs = <PredictModelEventArgs>arg;
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
predictArgs, predictArgs.model, predictArgs.filePath);
});
view.on(SourceModelSelectedEventName, () => {
view.refresh();
});
@@ -158,7 +207,7 @@ export class ModelManagementController extends ControllerBase {
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise<void> {
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
if (filePath) {
await service.registerLocalModel(filePath, details);
} else {
@@ -175,7 +224,7 @@ export class ModelManagementController extends ControllerBase {
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined,
details: RegisteredModel | undefined): Promise<void> {
details: RegisteredModelDetails | undefined): Promise<void> {
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
throw Error(constants.invalidAzureResourceError);
}
@@ -188,4 +237,47 @@ export class ModelManagementController extends ControllerBase {
throw Error(constants.invalidModelToRegisterError);
}
}
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
return await predictService.getDatabaseList();
}
public async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
return await predictService.getTableList(databaseName);
}
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
return await predictService.getTableColumnsList(databaseTable);
}
private async generatePredictScript(
predictService: PredictService,
predictParams: PredictParameters,
registeredModel: RegisteredModel | undefined,
filePath: string | undefined
): Promise<string> {
if (!predictParams) {
throw Error(constants.invalidModelToPredictError);
}
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
return result;
}
private async downloadAzureModel(
azureService: AzureModelRegistryService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined): Promise<string> {
if (!account || !subscription || !resourceGroup || !workspace || !model) {
throw Error(constants.invalidAzureResourceError);
}
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
return filePath;
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
}