Machine Learning Services - Model detection in predict wizard (#9609)

* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
Leila Lali
2020-03-25 13:18:19 -07:00
committed by GitHub
parent 176edde2aa
commit ab82c04766
44 changed files with 2265 additions and 376 deletions

View File

@@ -9,15 +9,15 @@ import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { RegisterModelWizard } from './registerModels/registerModelWizard';
@@ -39,7 +39,7 @@ export class ModelManagementController extends ControllerBase {
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
private _registeredModelService: RegisteredModelService,
private _registeredModelService: DeployedModelService,
private _predictService: PredictService) {
super(apiWrapper);
}
@@ -61,7 +61,7 @@ export class ModelManagementController extends ControllerBase {
// Open view
//
view.open();
await view.open();
await view.refresh();
return view;
}
@@ -74,10 +74,15 @@ export class ModelManagementController extends ControllerBase {
let view = new PredictWizard(this._apiWrapper, this._root);
this.registerEvents(view);
view.on(LoadModelParametersEventName, async () => {
const modelArtifact = await view.getModelFileName();
await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
modelArtifact?.filePath);
});
// Open view
//
view.open();
await view.open();
await view.refresh();
return view;
}
@@ -151,6 +156,11 @@ export class ModelManagementController extends ControllerBase {
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
predictArgs, predictArgs.model, predictArgs.filePath);
});
view.on(DownloadRegisteredModelEventName, async (arg) => {
let model = <RegisteredModel>arg;
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
model);
});
view.on(SourceModelSelectedEventName, () => {
view.refresh();
});
@@ -191,8 +201,8 @@ export class ModelManagementController extends ControllerBase {
return await service.getWorkspaces(account, subscription, group);
}
private async getRegisteredModels(registeredModelService: RegisteredModelService): Promise<RegisteredModel[]> {
return registeredModelService.getRegisteredModels();
private async getRegisteredModels(registeredModelService: DeployedModelService): Promise<RegisteredModel[]> {
return registeredModelService.getDeployedModels();
}
private async getAzureModels(
@@ -207,9 +217,9 @@ export class ModelManagementController extends ControllerBase {
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
private async registerLocalModel(service: DeployedModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
if (filePath) {
await service.registerLocalModel(filePath, details);
await service.deployLocalModel(filePath, details);
} else {
throw Error(constants.invalidModelToRegisterError);
@@ -218,7 +228,7 @@ export class ModelManagementController extends ControllerBase {
private async registerAzureModel(
azureService: AzureModelRegistryService,
service: RegisteredModelService,
service: DeployedModelService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
@@ -231,7 +241,7 @@ export class ModelManagementController extends ControllerBase {
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
await service.registerLocalModel(filePath, details);
await service.deployLocalModel(filePath, details);
await fs.promises.unlink(filePath);
} else {
throw Error(constants.invalidModelToRegisterError);
@@ -246,7 +256,7 @@ export class ModelManagementController extends ControllerBase {
return await predictService.getTableList(databaseName);
}
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
return await predictService.getTableColumnsList(databaseTable);
}
@@ -263,6 +273,24 @@ export class ModelManagementController extends ControllerBase {
return result;
}
private async downloadRegisteredModel(
registeredModelService: DeployedModelService,
model: RegisteredModel | undefined): Promise<string> {
if (!model) {
throw Error(constants.invalidModelToPredictError);
}
return await registeredModelService.downloadModel(model);
}
private async loadModelParameters(
registeredModelService: DeployedModelService,
model: string | undefined): Promise<ModelParameters | undefined> {
if (!model) {
return undefined;
}
return await registeredModelService.loadModelParameters(model);
}
private async downloadAzureModel(
azureService: AzureModelRegistryService,
account: azdata.Account | undefined,