mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-08 17:24:01 -05:00
Machine Learning Services Extension - Predict wizard (#9450)
*MLS extension - Added predict wizard
This commit is contained in:
@@ -9,14 +9,23 @@ import { azureResource } from '../../typings/azure-resource';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
|
||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
|
||||
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
|
||||
import { RegisteredModelsDialog } from './registeredModelsDialog';
|
||||
import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName } from './modelViewBase';
|
||||
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
|
||||
import {
|
||||
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
|
||||
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
|
||||
} from './modelViewBase';
|
||||
import { ControllerBase } from '../controllerBase';
|
||||
import { RegisterModelWizard } from './registerModelWizard';
|
||||
import { RegisterModelWizard } from './registerModels/registerModelWizard';
|
||||
import * as fs from 'fs';
|
||||
import * as constants from '../../common/constants';
|
||||
import { PredictWizard } from './prediction/predictWizard';
|
||||
import { AzureModelResource } from '../interfaces';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
|
||||
/**
|
||||
* Model management UI controller
|
||||
@@ -30,7 +39,8 @@ export class ModelManagementController extends ControllerBase {
|
||||
apiWrapper: ApiWrapper,
|
||||
private _root: string,
|
||||
private _amlService: AzureModelRegistryService,
|
||||
private _registeredModelService: RegisteredModelService) {
|
||||
private _registeredModelService: RegisteredModelService,
|
||||
private _predictService: PredictService) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
|
||||
@@ -56,6 +66,23 @@ export class ModelManagementController extends ControllerBase {
|
||||
return view;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the wizard for prediction
|
||||
*/
|
||||
public async predictModel(): Promise<ModelViewBase> {
|
||||
|
||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||
|
||||
this.registerEvents(view);
|
||||
|
||||
// Open view
|
||||
//
|
||||
view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Register events in the main view
|
||||
* @param view main view
|
||||
@@ -102,6 +129,28 @@ export class ModelManagementController extends ControllerBase {
|
||||
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
|
||||
});
|
||||
view.on(DownloadAzureModelEventName, async (arg) => {
|
||||
let registerArgs = <AzureModelResource>arg;
|
||||
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, async () => {
|
||||
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
|
||||
});
|
||||
view.on(ListTableNamesEventName, async (arg) => {
|
||||
let dbName = <string>arg;
|
||||
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
|
||||
});
|
||||
view.on(ListColumnNamesEventName, async (arg) => {
|
||||
let tableColumnsArgs = <DatabaseTable>arg;
|
||||
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
|
||||
tableColumnsArgs);
|
||||
});
|
||||
view.on(PredictModelEventName, async (arg) => {
|
||||
let predictArgs = <PredictModelEventArgs>arg;
|
||||
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
|
||||
predictArgs, predictArgs.model, predictArgs.filePath);
|
||||
});
|
||||
view.on(SourceModelSelectedEventName, () => {
|
||||
view.refresh();
|
||||
});
|
||||
@@ -158,7 +207,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
||||
}
|
||||
|
||||
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise<void> {
|
||||
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
if (filePath) {
|
||||
await service.registerLocalModel(filePath, details);
|
||||
} else {
|
||||
@@ -175,7 +224,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
workspace: Workspace | undefined,
|
||||
model: WorkspaceModel | undefined,
|
||||
details: RegisteredModel | undefined): Promise<void> {
|
||||
details: RegisteredModelDetails | undefined): Promise<void> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
|
||||
throw Error(constants.invalidAzureResourceError);
|
||||
}
|
||||
@@ -188,4 +237,47 @@ export class ModelManagementController extends ControllerBase {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
|
||||
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
||||
return await predictService.getDatabaseList();
|
||||
}
|
||||
|
||||
public async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
|
||||
return await predictService.getTableList(databaseName);
|
||||
}
|
||||
|
||||
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
|
||||
return await predictService.getTableColumnsList(databaseTable);
|
||||
}
|
||||
|
||||
private async generatePredictScript(
|
||||
predictService: PredictService,
|
||||
predictParams: PredictParameters,
|
||||
registeredModel: RegisteredModel | undefined,
|
||||
filePath: string | undefined
|
||||
): Promise<string> {
|
||||
if (!predictParams) {
|
||||
throw Error(constants.invalidModelToPredictError);
|
||||
}
|
||||
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
|
||||
return result;
|
||||
}
|
||||
|
||||
private async downloadAzureModel(
|
||||
azureService: AzureModelRegistryService,
|
||||
account: azdata.Account | undefined,
|
||||
subscription: azureResource.AzureResourceSubscription | undefined,
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
workspace: Workspace | undefined,
|
||||
model: WorkspaceModel | undefined): Promise<string> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace || !model) {
|
||||
throw Error(constants.invalidAzureResourceError);
|
||||
}
|
||||
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
|
||||
if (filePath) {
|
||||
return filePath;
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user