mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
Machine Learning - Supporting multiple model import (#9869)
* Machine Learning Extension - Changed the deploy wizard to deploy multiple files
This commit is contained in:
@@ -117,6 +117,7 @@ export const extLangUpdateFailedError = localize('extLang.updateFailedError', "F
|
|||||||
|
|
||||||
export const modelArtifactName = localize('models.artifactName', "Artifact Name");
|
export const modelArtifactName = localize('models.artifactName', "Artifact Name");
|
||||||
export const modelName = localize('models.name', "Name");
|
export const modelName = localize('models.name', "Name");
|
||||||
|
export const modelFileName = localize('models.fileName', "File");
|
||||||
export const modelDescription = localize('models.description', "Description");
|
export const modelDescription = localize('models.description', "Description");
|
||||||
export const modelCreated = localize('models.created', "Date Created");
|
export const modelCreated = localize('models.created', "Date Created");
|
||||||
export const modelVersion = localize('models.version', "Version");
|
export const modelVersion = localize('models.version', "Version");
|
||||||
@@ -139,9 +140,9 @@ export const azureModels = localize('models.azureModels', "Models");
|
|||||||
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
|
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
|
||||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Enter model source details");
|
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?");
|
||||||
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
|
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
|
||||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
|
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details");
|
||||||
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
||||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||||
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
|
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
|
||||||
@@ -151,9 +152,9 @@ export const deployModelTitle = localize('models.deployModelTitle', "Deploy mode
|
|||||||
export const makePredictionTitle = localize('models.makePredictionTitle', "Make prediction");
|
export const makePredictionTitle = localize('models.makePredictionTitle', "Make prediction");
|
||||||
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
|
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
|
||||||
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
|
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
|
||||||
export const localModelSource = localize('models.localModelSource', "Upload file");
|
export const localModelSource = localize('models.localModelSource', "File upload");
|
||||||
export const azureModelSource = localize('models.azureModelSource', "Import from AzureML registry");
|
export const azureModelSource = localize('models.azureModelSource', "Azure Machine Learning");
|
||||||
export const registeredModelsSource = localize('models.registeredModelsSource', "Select managed models");
|
export const registeredModelsSource = localize('models.registeredModelsSource', "Imported models");
|
||||||
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
|
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
|
||||||
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
|
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
|
||||||
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
||||||
@@ -161,7 +162,8 @@ export const invalidModelToPredictError = localize('models.invalidModelToPredict
|
|||||||
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
||||||
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||||
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");
|
export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); }
|
||||||
|
|
||||||
export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
|
export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
|
||||||
export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
|
export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
|
||||||
|
|
||||||
|
|||||||
@@ -248,3 +248,15 @@ export async function writeFileFromHex(content: string): Promise<string> {
|
|||||||
await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex'));
|
await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex'));
|
||||||
return tempFilePath;
|
return tempFilePath;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param filePath Returns file name
|
||||||
|
*/
|
||||||
|
export function getFileName(filePath: string) {
|
||||||
|
if (filePath) {
|
||||||
|
return filePath.replace(/^.*[\\\/]/, '');
|
||||||
|
} else {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -86,22 +86,23 @@ export class DeployedModelService {
|
|||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
if (connection) {
|
if (connection) {
|
||||||
let currentModels = await this.getDeployedModels();
|
let currentModels = await this.getDeployedModels();
|
||||||
await this._modelClient.deployModel(connection, filePath);
|
const content = await utils.readFileInHex(filePath);
|
||||||
let updatedModels = await this.getDeployedModels();
|
const fileName = details?.fileName || utils.getFileName(filePath);
|
||||||
if (details && updatedModels.length >= currentModels.length + 1) {
|
let modelToAdd: RegisteredModel = {
|
||||||
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
|
id: 0,
|
||||||
const addedModel = updatedModels[updatedModels.length - 1];
|
artifactName: fileName,
|
||||||
addedModel.title = details.title;
|
content: content,
|
||||||
addedModel.description = details.description;
|
title: details?.title || fileName,
|
||||||
addedModel.version = details.version;
|
description: details?.description,
|
||||||
const updatedModel = await this.updateModel(addedModel);
|
version: details?.version
|
||||||
if (!updatedModel) {
|
};
|
||||||
throw Error(constants.updateModelFailedError);
|
await this._queryRunner.safeRunQuery(connection, this.getInsertModelQuery(connection.databaseName, modelToAdd));
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
let updatedModels = await this.getDeployedModels();
|
||||||
throw Error(constants.importModelFailedError);
|
if (updatedModels.length < currentModels.length + 1) {
|
||||||
|
throw Error(constants.importModelFailedError(details?.title, filePath));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||||
@@ -115,20 +116,6 @@ export class DeployedModelService {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
|
||||||
let connection = await this.getCurrentConnection();
|
|
||||||
let updatedModel: RegisteredModel | undefined = undefined;
|
|
||||||
if (connection) {
|
|
||||||
const query = this.getUpdateModelQuery(connection.databaseName, model);
|
|
||||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
|
||||||
if (result?.rows && result.rows.length > 0) {
|
|
||||||
const row = result.rows[0];
|
|
||||||
updatedModel = this.loadModelData(row);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return updatedModel;
|
|
||||||
}
|
|
||||||
|
|
||||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||||
return await this._apiWrapper.getCurrentConnection();
|
return await this._apiWrapper.getCurrentConnection();
|
||||||
}
|
}
|
||||||
@@ -190,7 +177,7 @@ export class DeployedModelService {
|
|||||||
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
|
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
|
||||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||||
[artifact_name] [varchar](256) NOT NULL,
|
[artifact_name] [varchar](256) NOT NULL,
|
||||||
[group_path] [varchar](256) NOT NULL,
|
[group_path] [varchar](256) NULL,
|
||||||
[artifact_content] [varbinary](max) NOT NULL,
|
[artifact_content] [varbinary](max) NOT NULL,
|
||||||
[artifact_initial_size] [bigint] NULL,
|
[artifact_initial_size] [bigint] NULL,
|
||||||
[name] [varchar](256) NULL,
|
[name] [varchar](256) NULL,
|
||||||
@@ -207,20 +194,24 @@ export class DeployedModelService {
|
|||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|
||||||
public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
|
public getInsertModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
|
||||||
let updateScript = `
|
let updateScript = `
|
||||||
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
|
Insert into ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||||
SET
|
(artifact_name, group_path, artifact_content, name, version, description)
|
||||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
values (
|
||||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
'${utils.doubleEscapeSingleQuotes(model.artifactName || '')}',
|
||||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
'ADS',
|
||||||
WHERE artifact_id = ${model.id}`;
|
${utils.doubleEscapeSingleQuotes(model.content || '')},
|
||||||
|
'${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||||
|
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||||
|
'${utils.doubleEscapeSingleQuotes(model.description || '')}')
|
||||||
|
`;
|
||||||
|
|
||||||
return `
|
return `
|
||||||
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
|
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
|
||||||
SELECT artifact_id, artifact_name, name, description, version, created
|
SELECT artifact_id, artifact_name, name, description, version, created
|
||||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||||
WHERE artifact_id = ${model.id};
|
WHERE artifact_id = SCOPE_IDENTITY();
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
|||||||
export interface RegisteredModel extends RegisteredModelDetails {
|
export interface RegisteredModel extends RegisteredModelDetails {
|
||||||
id: number;
|
id: number;
|
||||||
artifactName: string;
|
artifactName: string;
|
||||||
|
content?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ModelParameter {
|
export interface ModelParameter {
|
||||||
@@ -71,6 +72,7 @@ export interface RegisteredModelDetails {
|
|||||||
created?: string;
|
created?: string;
|
||||||
version?: string;
|
version?: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
|
fileName?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -230,3 +232,4 @@ export interface ArtifactAPIGetArtifactContentInformation2OptionalParams extends
|
|||||||
*/
|
*/
|
||||||
accountName?: string;
|
accountName?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -248,9 +248,10 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object);
|
||||||
testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => {
|
|
||||||
|
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => {
|
||||||
deployed = true;
|
deployed = true;
|
||||||
return Promise.resolve();
|
return Promise.resolve(result);
|
||||||
});
|
});
|
||||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||||
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
||||||
@@ -259,7 +260,15 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||||
await should(service.deployLocalModel('', model)).resolved();
|
let tempFilePath: string = '';
|
||||||
|
try {
|
||||||
|
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||||
|
await fs.promises.writeFile(tempFilePath, 'test');
|
||||||
|
await should(service.deployLocalModel(tempFilePath, model)).resolved();
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
await utils.deleteFile(tempFilePath);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
||||||
@@ -306,7 +315,7 @@ describe('DeployedModelService', () => {
|
|||||||
CREATE TABLE [dbo].[ta[[b]]le](
|
CREATE TABLE [dbo].[ta[[b]]le](
|
||||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||||
[artifact_name] [varchar](256) NOT NULL,
|
[artifact_name] [varchar](256) NOT NULL,
|
||||||
[group_path] [varchar](256) NOT NULL,
|
[group_path] [varchar](256) NULL,
|
||||||
[artifact_content] [varbinary](max) NOT NULL,
|
[artifact_content] [varbinary](max) NOT NULL,
|
||||||
[artifact_initial_size] [bigint] NULL,
|
[artifact_initial_size] [bigint] NULL,
|
||||||
[name] [varchar](256) NULL,
|
[name] [varchar](256) NULL,
|
||||||
@@ -345,7 +354,7 @@ describe('DeployedModelService', () => {
|
|||||||
should.deepEqual(expected, actual);
|
should.deepEqual(expected, actual);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('getUpdateModelQuery should escape db name', async function (): Promise<void> {
|
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
|
||||||
const testContext = createContext();
|
const testContext = createContext();
|
||||||
const dbName = 'curre[n]tDb';
|
const dbName = 'curre[n]tDb';
|
||||||
const model: RegisteredModel =
|
const model: RegisteredModel =
|
||||||
@@ -367,16 +376,17 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||||
const expected = `
|
const expected = `
|
||||||
UPDATE [dbo].[ta[[b]]le]
|
Insert into [dbo].[ta[[b]]le]
|
||||||
SET
|
(artifact_name, group_path, artifact_content, name, version, description)
|
||||||
name = 'title1',
|
values (
|
||||||
version = '1.1',
|
'name1',
|
||||||
description = 'desc1'
|
'ADS',
|
||||||
WHERE artifact_id = 1`;
|
,
|
||||||
const actual = service.getUpdateModelQuery(dbName, model);
|
'title1',
|
||||||
|
'1.1',
|
||||||
|
'desc1')`;
|
||||||
|
const actual = service.getInsertModelQuery(dbName, model);
|
||||||
should.equal(actual.indexOf(expected) > 0, true);
|
should.equal(actual.indexOf(expected) > 0, true);
|
||||||
//should.deepEqual(actual, expected);
|
|
||||||
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('getModelContentQuery should escape db name', async function (): Promise<void> {
|
it('getModelContentQuery should escape db name', async function (): Promise<void> {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ describe('Azure Models Component', () => {
|
|||||||
let testContext = createContext();
|
let testContext = createContext();
|
||||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||||
|
|
||||||
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent);
|
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent, false);
|
||||||
view.registerComponent(testContext.view.modelBuilder);
|
view.registerComponent(testContext.view.modelBuilder);
|
||||||
|
|
||||||
let accounts: azdata.Account[] = [
|
let accounts: azdata.Account[] = [
|
||||||
@@ -88,12 +88,15 @@ describe('Azure Models Component', () => {
|
|||||||
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||||
});
|
});
|
||||||
await view.refresh();
|
await view.refresh();
|
||||||
testContext.onClick.fire();
|
testContext.onClick.fire(true);
|
||||||
should.notEqual(view.data, undefined);
|
should.notEqual(view.data, undefined);
|
||||||
should.deepEqual(view.data?.account, accounts[0]);
|
should.equal(view.data?.length, 1);
|
||||||
should.deepEqual(view.data?.subscription, subscriptions[0]);
|
if (view.data) {
|
||||||
should.deepEqual(view.data?.group, groups[0]);
|
should.deepEqual(view.data[0].account, accounts[0]);
|
||||||
should.deepEqual(view.data?.workspace, workspaces[0]);
|
should.deepEqual(view.data[0].subscription, subscriptions[0]);
|
||||||
should.deepEqual(view.data?.model, models[0]);
|
should.deepEqual(view.data[0].group, groups[0]);
|
||||||
|
should.deepEqual(view.data[0].workspace, workspaces[0]);
|
||||||
|
should.deepEqual(view.data[0].model, models[0]);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import 'mocha';
|
|||||||
import { createContext } from './utils';
|
import { createContext } from './utils';
|
||||||
import {
|
import {
|
||||||
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
|
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
|
||||||
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName
|
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType
|
||||||
}
|
}
|
||||||
from '../../../views/models/modelViewBase';
|
from '../../../views/models/modelViewBase';
|
||||||
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
|
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||||
@@ -164,9 +164,25 @@ describe('Predict Wizard', () => {
|
|||||||
view.on(DownloadRegisteredModelEventName, () => {
|
view.on(DownloadRegisteredModelEventName, () => {
|
||||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
|
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
|
||||||
});
|
});
|
||||||
|
if (view.modelBrowsePage) {
|
||||||
|
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
|
||||||
|
}
|
||||||
await view.refresh();
|
await view.refresh();
|
||||||
should.notEqual(view.azureModelsComponent?.data, undefined);
|
should.notEqual(view.azureModelsComponent?.data, undefined);
|
||||||
|
|
||||||
|
if (view.modelBrowsePage) {
|
||||||
|
view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
|
||||||
|
}
|
||||||
|
await view.refresh();
|
||||||
|
testContext.onClick.fire();
|
||||||
|
|
||||||
|
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels);
|
||||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||||
|
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined);
|
||||||
|
if (view.modelBrowsePage?.registeredModelsComponent?.data) {
|
||||||
|
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
should.notEqual(await view.getModelFileName(), undefined);
|
should.notEqual(await view.getModelFileName(), undefined);
|
||||||
await view.columnsSelectionPage?.onEnter();
|
await view.columnsSelectionPage?.onEnter();
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import * as azdata from 'azdata';
|
|||||||
import * as should from 'should';
|
import * as should from 'should';
|
||||||
import 'mocha';
|
import 'mocha';
|
||||||
import { createContext } from './utils';
|
import { createContext } from './utils';
|
||||||
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName } from '../../../views/models/modelViewBase';
|
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType } from '../../../views/models/modelViewBase';
|
||||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
import { azureResource } from '../../../typings/azure-resource';
|
import { azureResource } from '../../../typings/azure-resource';
|
||||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
@@ -97,6 +97,10 @@ describe('Register Model Wizard', () => {
|
|||||||
view.on(ListAzureModelsEventName, () => {
|
view.on(ListAzureModelsEventName, () => {
|
||||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (view.modelBrowsePage) {
|
||||||
|
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
|
||||||
|
}
|
||||||
await view.refresh();
|
await view.refresh();
|
||||||
should.notEqual(view.azureModelsComponent?.data ,undefined);
|
should.notEqual(view.azureModelsComponent?.data ,undefined);
|
||||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||||
|
|||||||
@@ -32,8 +32,13 @@ export function createViewContext(): ViewTestContext {
|
|||||||
onDidClick: onClick.event
|
onDidClick: onClick.event
|
||||||
});
|
});
|
||||||
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
|
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
|
||||||
|
checked: true,
|
||||||
onDidClick: onClick.event
|
onDidClick: onClick.event
|
||||||
});
|
});
|
||||||
|
let checkbox: azdata.CheckBoxComponent = Object.assign({}, componentBase, {
|
||||||
|
checked: true,
|
||||||
|
onChanged: onClick.event
|
||||||
|
});
|
||||||
let container = {
|
let container = {
|
||||||
clearItems: () => { },
|
clearItems: () => { },
|
||||||
addItems: () => { },
|
addItems: () => { },
|
||||||
@@ -58,6 +63,11 @@ export function createViewContext(): ViewTestContext {
|
|||||||
withProperties: () => radioButtonBuilder,
|
withProperties: () => radioButtonBuilder,
|
||||||
withValidation: () => radioButtonBuilder
|
withValidation: () => radioButtonBuilder
|
||||||
};
|
};
|
||||||
|
let checkBoxBuilder: azdata.ComponentBuilder<azdata.CheckBoxComponent> = {
|
||||||
|
component: () => checkbox,
|
||||||
|
withProperties: () => checkBoxBuilder,
|
||||||
|
withValidation: () => checkBoxBuilder
|
||||||
|
};
|
||||||
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
|
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
|
||||||
onTextChanged: undefined!,
|
onTextChanged: undefined!,
|
||||||
onEnterKeyPressed: undefined!,
|
onEnterKeyPressed: undefined!,
|
||||||
@@ -85,6 +95,12 @@ export function createViewContext(): ViewTestContext {
|
|||||||
component: undefined!
|
component: undefined!
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let card: () => azdata.CardComponent = () => Object.assign({}, componentBase, {
|
||||||
|
label: '',
|
||||||
|
onDidActionClick: new vscode.EventEmitter<azdata.ActionDescriptor>().event,
|
||||||
|
onCardSelectedChanged: onClick.event
|
||||||
|
});
|
||||||
|
|
||||||
let declarativeTableBuilder: azdata.ComponentBuilder<azdata.DeclarativeTableComponent> = {
|
let declarativeTableBuilder: azdata.ComponentBuilder<azdata.DeclarativeTableComponent> = {
|
||||||
component: () => declarativeTable(),
|
component: () => declarativeTable(),
|
||||||
withProperties: () => declarativeTableBuilder,
|
withProperties: () => declarativeTableBuilder,
|
||||||
@@ -127,6 +143,15 @@ export function createViewContext(): ViewTestContext {
|
|||||||
withProperties: () => inputBoxBuilder,
|
withProperties: () => inputBoxBuilder,
|
||||||
withValidation: () => inputBoxBuilder
|
withValidation: () => inputBoxBuilder
|
||||||
};
|
};
|
||||||
|
let cardBuilder: azdata.ComponentBuilder<azdata.CardComponent> = {
|
||||||
|
component: () => {
|
||||||
|
let r = card();
|
||||||
|
return r;
|
||||||
|
},
|
||||||
|
withProperties: () => cardBuilder,
|
||||||
|
withValidation: () => cardBuilder
|
||||||
|
};
|
||||||
|
|
||||||
let imageBuilder: azdata.ComponentBuilder<azdata.ImageComponent> = {
|
let imageBuilder: azdata.ComponentBuilder<azdata.ImageComponent> = {
|
||||||
component: () => {
|
component: () => {
|
||||||
let r = image();
|
let r = image();
|
||||||
@@ -159,9 +184,9 @@ export function createViewContext(): ViewTestContext {
|
|||||||
flexContainer: () => flexBuilder,
|
flexContainer: () => flexBuilder,
|
||||||
splitViewContainer: undefined!,
|
splitViewContainer: undefined!,
|
||||||
dom: undefined!,
|
dom: undefined!,
|
||||||
card: undefined!,
|
card: () => cardBuilder,
|
||||||
inputBox: () => inputBoxBuilder,
|
inputBox: () => inputBoxBuilder,
|
||||||
checkBox: undefined!,
|
checkBox: () => checkBoxBuilder!,
|
||||||
radioButton: () => radioButtonBuilder,
|
radioButton: () => radioButtonBuilder,
|
||||||
webView: undefined!,
|
webView: undefined!,
|
||||||
editor: undefined!,
|
editor: undefined!,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import { AzureModelsTable } from './azureModelsTable';
|
|||||||
import { IDataComponent, AzureModelResource } from '../interfaces';
|
import { IDataComponent, AzureModelResource } from '../interfaces';
|
||||||
import { ModelArtifact } from './prediction/modelArtifact';
|
import { ModelArtifact } from './prediction/modelArtifact';
|
||||||
|
|
||||||
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource> {
|
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource[]> {
|
||||||
|
|
||||||
public azureModelsTable: AzureModelsTable | undefined;
|
public azureModelsTable: AzureModelsTable | undefined;
|
||||||
public azureFilterComponent: AzureResourceFilterComponent | undefined;
|
public azureFilterComponent: AzureResourceFilterComponent | undefined;
|
||||||
@@ -23,7 +23,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
|
|||||||
/**
|
/**
|
||||||
* Component to render a view to pick an azure model
|
* Component to render a view to pick an azure model
|
||||||
*/
|
*/
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,7 +33,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
|
|||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this);
|
this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this);
|
||||||
this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this);
|
this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect);
|
||||||
this._loader = modelBuilder.loadingComponent()
|
this._loader = modelBuilder.loadingComponent()
|
||||||
.withItem(this.azureModelsTable.component)
|
.withItem(this.azureModelsTable.component)
|
||||||
.withProperties({
|
.withProperties({
|
||||||
@@ -109,15 +109,16 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
|
|||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): AzureModelResource | undefined {
|
public get data(): AzureModelResource[] | undefined {
|
||||||
return Object.assign({}, this.azureFilterComponent?.data, {
|
return this.azureModelsTable?.data ? this.azureModelsTable?.data.map(x => Object.assign({}, this.azureFilterComponent?.data, {
|
||||||
model: this.azureModelsTable?.data
|
model: x
|
||||||
});
|
})) : undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getDownloadedModel(): Promise<ModelArtifact> {
|
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
|
||||||
if (!this._downloadedFile) {
|
const data = this.data;
|
||||||
this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(this.data));
|
if (!this._downloadedFile && data && data.length > 0) {
|
||||||
|
this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(data[0]));
|
||||||
}
|
}
|
||||||
return this._downloadedFile;
|
return this._downloadedFile;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ import { IDataComponent, AzureWorkspaceResource } from '../interfaces';
|
|||||||
/**
|
/**
|
||||||
* View to render azure models in a table
|
* View to render azure models in a table
|
||||||
*/
|
*/
|
||||||
export class AzureModelsTable extends ModelViewBase implements IDataComponent<WorkspaceModel> {
|
export class AzureModelsTable extends ModelViewBase implements IDataComponent<WorkspaceModel[]> {
|
||||||
|
|
||||||
private _table: azdata.DeclarativeTableComponent;
|
private _table: azdata.DeclarativeTableComponent;
|
||||||
private _selectedModelId: any;
|
private _selectedModel: WorkspaceModel[] = [];
|
||||||
private _models: WorkspaceModel[] | undefined;
|
private _models: WorkspaceModel[] | undefined;
|
||||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||||
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
||||||
@@ -25,7 +25,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
|||||||
/**
|
/**
|
||||||
* Creates a view to render azure models in a table
|
* Creates a view to render azure models in a table
|
||||||
*/
|
*/
|
||||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
this._table = this.registerComponent(this._modelBuilder);
|
this._table = this.registerComponent(this._modelBuilder);
|
||||||
}
|
}
|
||||||
@@ -123,17 +123,42 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
|||||||
|
|
||||||
private createTableRow(model: WorkspaceModel): any[] {
|
private createTableRow(model: WorkspaceModel): any[] {
|
||||||
if (this._modelBuilder) {
|
if (this._modelBuilder) {
|
||||||
let selectModelButton = this._modelBuilder.radioButton().withProperties({
|
let selectModelButton: azdata.Component;
|
||||||
name: 'amlModel',
|
let onSelectItem = (checked: boolean) => {
|
||||||
value: model.id,
|
const foundItem = this._selectedModel.find(x => x === model);
|
||||||
width: 15,
|
if (checked && !foundItem) {
|
||||||
height: 15,
|
this._selectedModel.push(model);
|
||||||
checked: false
|
} else if (foundItem) {
|
||||||
}).component();
|
this._selectedModel = this._selectedModel.filter(x => x !== model);
|
||||||
selectModelButton.onDidClick(() => {
|
}
|
||||||
this._selectedModelId = model.id;
|
|
||||||
this._onModelSelectionChanged.fire();
|
this._onModelSelectionChanged.fire();
|
||||||
});
|
};
|
||||||
|
if (this._multiSelect) {
|
||||||
|
const checkbox = this._modelBuilder.checkBox().withProperties({
|
||||||
|
name: 'amlModel',
|
||||||
|
value: model.id,
|
||||||
|
width: 15,
|
||||||
|
height: 15,
|
||||||
|
checked: false
|
||||||
|
}).component();
|
||||||
|
checkbox.onChanged(() => {
|
||||||
|
onSelectItem(checkbox.checked || false);
|
||||||
|
});
|
||||||
|
selectModelButton = checkbox;
|
||||||
|
} else {
|
||||||
|
const radioButton = this._modelBuilder.radioButton().withProperties({
|
||||||
|
name: 'amlModel',
|
||||||
|
value: model.id,
|
||||||
|
width: 15,
|
||||||
|
height: 15,
|
||||||
|
checked: false
|
||||||
|
}).component();
|
||||||
|
radioButton.onDidClick(() => {
|
||||||
|
onSelectItem(radioButton.checked || false);
|
||||||
|
});
|
||||||
|
selectModelButton = radioButton;
|
||||||
|
}
|
||||||
|
|
||||||
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
|
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,9 +168,9 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
|||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): WorkspaceModel | undefined {
|
public get data(): WorkspaceModel[] | undefined {
|
||||||
if (this._models && this._selectedModelId) {
|
if (this._models && this._selectedModel) {
|
||||||
return this._models.find(x => x.id === this._selectedModelId);
|
return this._selectedModel;
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import { IDataComponent } from '../interfaces';
|
|||||||
/**
|
/**
|
||||||
* View to pick local models file
|
* View to pick local models file
|
||||||
*/
|
*/
|
||||||
export class LocalModelsComponent extends ModelViewBase implements IDataComponent<string> {
|
export class LocalModelsComponent extends ModelViewBase implements IDataComponent<string[]> {
|
||||||
|
|
||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
private _flex: azdata.FlexContainer | undefined;
|
private _flex: azdata.FlexContainer | undefined;
|
||||||
@@ -24,7 +24,7 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen
|
|||||||
/**
|
/**
|
||||||
* Creates new view
|
* Creates new view
|
||||||
*/
|
*/
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,13 +49,15 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen
|
|||||||
let options: vscode.OpenDialogOptions = {
|
let options: vscode.OpenDialogOptions = {
|
||||||
canSelectFiles: true,
|
canSelectFiles: true,
|
||||||
canSelectFolders: false,
|
canSelectFolders: false,
|
||||||
canSelectMany: false,
|
canSelectMany: this._multiSelect,
|
||||||
filters: { 'ONNX File': ['onnx'] }
|
filters: { 'ONNX File': ['onnx'] }
|
||||||
};
|
};
|
||||||
|
|
||||||
const filePaths = await this.getLocalPaths(options);
|
const filePaths = await this.getLocalPaths(options);
|
||||||
if (this._localPath) {
|
if (this._localPath && filePaths && filePaths.length > 0) {
|
||||||
this._localPath.value = filePaths && filePaths.length > 0 ? filePaths[0] : '';
|
this._localPath.value = this._multiSelect ? filePaths.join(';') : filePaths[0];
|
||||||
|
} else if (this._localPath) {
|
||||||
|
this._localPath.value = '';
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -96,8 +98,12 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen
|
|||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): string {
|
public get data(): string[] {
|
||||||
return this._localPath?.value || '';
|
if (this._localPath?.value) {
|
||||||
|
return this._localPath?.value.split(';');
|
||||||
|
} else {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -0,0 +1,178 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import { ModelViewBase, ModelSourceType, ModelViewData } from './modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
|
import * as constants from '../../common/constants';
|
||||||
|
import { IPageView, IDataComponent } from '../interfaces';
|
||||||
|
import { LocalModelsComponent } from './localModelsComponent';
|
||||||
|
import { AzureModelsComponent } from './azureModelsComponent';
|
||||||
|
import { CurrentModelsTable } from './registerModels/currentModelsTable';
|
||||||
|
import * as utils from '../../common/utils';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View to pick model source
|
||||||
|
*/
|
||||||
|
export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
|
||||||
|
|
||||||
|
private _form: azdata.FormContainer | undefined;
|
||||||
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
|
public localModelsComponent: LocalModelsComponent | undefined;
|
||||||
|
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||||
|
public registeredModelsComponent: CurrentModelsTable | undefined;
|
||||||
|
|
||||||
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||||
|
super(apiWrapper, parent.root, parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param modelBuilder Register components
|
||||||
|
*/
|
||||||
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
|
||||||
|
this._formBuilder = modelBuilder.formContainer();
|
||||||
|
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this, this._multiSelect);
|
||||||
|
this.localModelsComponent.registerComponent(modelBuilder);
|
||||||
|
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect);
|
||||||
|
this.azureModelsComponent.registerComponent(modelBuilder);
|
||||||
|
this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect);
|
||||||
|
this.registeredModelsComponent.registerComponent(modelBuilder);
|
||||||
|
this.refresh();
|
||||||
|
this._form = this._formBuilder.component();
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): ModelViewData[] {
|
||||||
|
return this.modelsViewData;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the component
|
||||||
|
*/
|
||||||
|
public get component(): azdata.Component | undefined {
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refreshes the view
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
if (this._formBuilder) {
|
||||||
|
if (this.modelSourceType === ModelSourceType.Local) {
|
||||||
|
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||||
|
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.localModelsComponent.addComponents(this._formBuilder);
|
||||||
|
await this.localModelsComponent.refresh();
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.Azure) {
|
||||||
|
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||||
|
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.azureModelsComponent.addComponents(this._formBuilder);
|
||||||
|
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
await this.azureModelsComponent.refresh();
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
|
||||||
|
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||||
|
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.registeredModelsComponent.addComponents(this._formBuilder);
|
||||||
|
await this.registeredModelsComponent.refresh();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns page title
|
||||||
|
*/
|
||||||
|
public get title(): string {
|
||||||
|
return constants.modelSourcePageTitle;
|
||||||
|
}
|
||||||
|
|
||||||
|
public validate(): Promise<boolean> {
|
||||||
|
let validated = false;
|
||||||
|
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
|
||||||
|
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
|
||||||
|
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
|
||||||
|
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0;
|
||||||
|
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
|
||||||
|
validated = this.registeredModelsComponent.data !== undefined && this.registeredModelsComponent.data.length > 0;
|
||||||
|
}
|
||||||
|
if (!validated) {
|
||||||
|
this.showErrorMessage(constants.invalidModelToSelectError);
|
||||||
|
}
|
||||||
|
return Promise.resolve(validated);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async onLeave(): Promise<void> {
|
||||||
|
this.modelsViewData = [];
|
||||||
|
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
|
||||||
|
if (this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0) {
|
||||||
|
this.modelsViewData = this.localModelsComponent.data.map(x => {
|
||||||
|
const fileName = utils.getFileName(x);
|
||||||
|
return {
|
||||||
|
modelData: x,
|
||||||
|
modelDetails: {
|
||||||
|
title: fileName,
|
||||||
|
fileName: fileName
|
||||||
|
}
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
|
||||||
|
if (this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0) {
|
||||||
|
this.modelsViewData = this.azureModelsComponent.data.map(x => {
|
||||||
|
return {
|
||||||
|
modelData: {
|
||||||
|
account: x.account,
|
||||||
|
subscription: x.subscription,
|
||||||
|
group: x.group,
|
||||||
|
workspace: x.workspace,
|
||||||
|
model: x.model
|
||||||
|
},
|
||||||
|
modelDetails: {
|
||||||
|
title: x.model?.name || '',
|
||||||
|
fileName: x.model?.name
|
||||||
|
}
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
|
||||||
|
if (this.registeredModelsComponent.data !== undefined) {
|
||||||
|
this.modelsViewData = this.registeredModelsComponent.data.map(x => {
|
||||||
|
return {
|
||||||
|
modelData: x,
|
||||||
|
modelDetails: {
|
||||||
|
title: ''
|
||||||
|
}
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async disposePage(): Promise<void> {
|
||||||
|
if (this.azureModelsComponent) {
|
||||||
|
await this.azureModelsComponent.disposeComponent();
|
||||||
|
|
||||||
|
}
|
||||||
|
if (this.registeredModelsComponent) {
|
||||||
|
await this.registeredModelsComponent.disposeComponent();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,25 +4,21 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import { ModelViewBase } from './modelViewBase';
|
import { ModelViewBase, ModelViewData } from './modelViewBase';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
import { IDataComponent } from '../interfaces';
|
import { IDataComponent } from '../interfaces';
|
||||||
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to pick local models file
|
* View to pick local models file
|
||||||
*/
|
*/
|
||||||
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModelDetails> {
|
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<ModelViewData[]> {
|
||||||
|
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||||
private _form: azdata.FormContainer | undefined;
|
|
||||||
private _nameComponent: azdata.InputBoxComponent | undefined;
|
|
||||||
private _descriptionComponent: azdata.InputBoxComponent | undefined;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates new view
|
* Creates new view
|
||||||
*/
|
*/
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,73 +27,162 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone
|
|||||||
* @param modelBuilder Register the components
|
* @param modelBuilder Register the components
|
||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
this._nameComponent = modelBuilder.inputBox().withProperties({
|
this._table = modelBuilder.declarativeTable()
|
||||||
value: '',
|
.withProperties<azdata.DeclarativeTableProperties>(
|
||||||
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength
|
{
|
||||||
}).component();
|
columns: [
|
||||||
this._descriptionComponent = modelBuilder.inputBox().withProperties({
|
{ // Name
|
||||||
value: '',
|
displayName: constants.modelFileName,
|
||||||
multiline: true,
|
ariaLabel: constants.modelFileName,
|
||||||
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength,
|
valueType: azdata.DeclarativeDataType.string,
|
||||||
hight: '50px'
|
isReadOnly: true,
|
||||||
}).component();
|
width: 150,
|
||||||
|
headerCssStyles: {
|
||||||
|
...constants.cssStyles.tableHeader
|
||||||
|
},
|
||||||
|
rowCssStyles: {
|
||||||
|
...constants.cssStyles.tableRow
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ // Name
|
||||||
|
displayName: constants.modelName,
|
||||||
|
ariaLabel: constants.modelName,
|
||||||
|
valueType: azdata.DeclarativeDataType.component,
|
||||||
|
isReadOnly: true,
|
||||||
|
width: 150,
|
||||||
|
headerCssStyles: {
|
||||||
|
...constants.cssStyles.tableHeader
|
||||||
|
},
|
||||||
|
rowCssStyles: {
|
||||||
|
...constants.cssStyles.tableRow
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ // Created
|
||||||
|
displayName: constants.modelDescription,
|
||||||
|
ariaLabel: constants.modelDescription,
|
||||||
|
valueType: azdata.DeclarativeDataType.component,
|
||||||
|
isReadOnly: true,
|
||||||
|
width: 100,
|
||||||
|
headerCssStyles: {
|
||||||
|
...constants.cssStyles.tableHeader
|
||||||
|
},
|
||||||
|
rowCssStyles: {
|
||||||
|
...constants.cssStyles.tableRow
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ // Action
|
||||||
|
displayName: '',
|
||||||
|
valueType: azdata.DeclarativeDataType.component,
|
||||||
|
isReadOnly: true,
|
||||||
|
width: 50,
|
||||||
|
headerCssStyles: {
|
||||||
|
...constants.cssStyles.tableHeader
|
||||||
|
},
|
||||||
|
rowCssStyles: {
|
||||||
|
...constants.cssStyles.tableRow
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
data: [],
|
||||||
|
ariaLabel: constants.mlsConfigTitle
|
||||||
|
})
|
||||||
|
.component();
|
||||||
|
|
||||||
this._form = modelBuilder.formContainer().withFormItems([{
|
return this._table;
|
||||||
title: constants.modelName,
|
|
||||||
component: this._nameComponent
|
|
||||||
}, {
|
|
||||||
title: constants.modelDescription,
|
|
||||||
component: this._descriptionComponent
|
|
||||||
}]).component();
|
|
||||||
return this._form;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._nameComponent && this._descriptionComponent) {
|
if (this._table) {
|
||||||
formBuilder.addFormItems([{
|
formBuilder.addFormItems([{
|
||||||
title: constants.modelName,
|
title: '',
|
||||||
component: this._nameComponent
|
component: this._table
|
||||||
}, {
|
|
||||||
title: constants.modelDescription,
|
|
||||||
component: this._descriptionComponent
|
|
||||||
}]);
|
}]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._nameComponent && this._descriptionComponent) {
|
if (this._table) {
|
||||||
formBuilder.removeFormItem({
|
formBuilder.removeFormItem({
|
||||||
title: constants.modelName,
|
title: '',
|
||||||
component: this._nameComponent
|
component: this._table
|
||||||
});
|
|
||||||
formBuilder.removeFormItem({
|
|
||||||
title: constants.modelDescription,
|
|
||||||
component: this._descriptionComponent
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load data in the component
|
||||||
|
* @param workspaceResource Azure workspace
|
||||||
|
*/
|
||||||
|
public async loadData(): Promise<void> {
|
||||||
|
|
||||||
|
const models = this.modelsViewData;
|
||||||
|
if (this._table && models) {
|
||||||
|
|
||||||
|
let tableData: any[][] = [];
|
||||||
|
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
|
||||||
|
this._table.data = tableData;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private createTableRow(model: ModelViewData | undefined): any[] {
|
||||||
|
if (this._modelBuilder && model && model.modelDetails) {
|
||||||
|
const nameComponent = this._modelBuilder.inputBox().withProperties({
|
||||||
|
value: model.modelDetails.title,
|
||||||
|
width: this.componentMaxLength - 100,
|
||||||
|
required: true
|
||||||
|
}).component();
|
||||||
|
const descriptionComponent = this._modelBuilder.inputBox().withProperties({
|
||||||
|
value: model.modelDetails.description,
|
||||||
|
width: this.componentMaxLength
|
||||||
|
}).component();
|
||||||
|
descriptionComponent.onTextChanged(() => {
|
||||||
|
if (model.modelDetails) {
|
||||||
|
model.modelDetails.description = descriptionComponent.value;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
nameComponent.onTextChanged(() => {
|
||||||
|
if (model.modelDetails) {
|
||||||
|
model.modelDetails.title = nameComponent.value || '';
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let deleteButton = this._modelBuilder.button().withProperties({
|
||||||
|
label: '',
|
||||||
|
title: constants.deleteTitle,
|
||||||
|
width: 15,
|
||||||
|
height: 15,
|
||||||
|
iconPath: {
|
||||||
|
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
|
||||||
|
light: this.asAbsolutePath('images/light/delete.svg')
|
||||||
|
},
|
||||||
|
}).component();
|
||||||
|
deleteButton.onDidClick(async () => {
|
||||||
|
this.modelsViewData = this.modelsViewData.filter(x => x !== model);
|
||||||
|
await this.refresh();
|
||||||
|
});
|
||||||
|
return [model.modelDetails.fileName, nameComponent, descriptionComponent, deleteButton];
|
||||||
|
}
|
||||||
|
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): RegisteredModelDetails {
|
public get data(): ModelViewData[] {
|
||||||
return {
|
return this.modelsViewData;
|
||||||
title: this._nameComponent?.value || '',
|
|
||||||
description: this._descriptionComponent?.value
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the component
|
* Returns the component
|
||||||
*/
|
*/
|
||||||
public get component(): azdata.Component | undefined {
|
public get component(): azdata.Component | undefined {
|
||||||
return this._form;
|
return this._table;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Refreshes the view
|
* Refreshes the view
|
||||||
*/
|
*/
|
||||||
public async refresh(): Promise<void> {
|
public async refresh(): Promise<void> {
|
||||||
|
await this.loadData();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,17 +4,16 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import { ModelViewBase } from './modelViewBase';
|
import { ModelViewBase, ModelViewData } from './modelViewBase';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
import { IPageView, IDataComponent } from '../interfaces';
|
import { IPageView, IDataComponent } from '../interfaces';
|
||||||
import { ModelDetailsComponent } from './modelDetailsComponent';
|
import { ModelDetailsComponent } from './modelDetailsComponent';
|
||||||
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to pick model details
|
* View to pick model details
|
||||||
*/
|
*/
|
||||||
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModelDetails> {
|
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
|
||||||
|
|
||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
private _formBuilder: azdata.FormBuilder | undefined;
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
@@ -31,9 +30,8 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
|||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
|
||||||
this._formBuilder = modelBuilder.formContainer();
|
this._formBuilder = modelBuilder.formContainer();
|
||||||
this.modelDetails = new ModelDetailsComponent(this._apiWrapper, this);
|
this.modelDetails = new ModelDetailsComponent(this._apiWrapper, modelBuilder, this);
|
||||||
this.modelDetails.registerComponent(modelBuilder);
|
this.modelDetails.registerComponent(modelBuilder);
|
||||||
|
|
||||||
this.modelDetails.addComponents(this._formBuilder);
|
this.modelDetails.addComponents(this._formBuilder);
|
||||||
this.refresh();
|
this.refresh();
|
||||||
this._form = this._formBuilder.component();
|
this._form = this._formBuilder.component();
|
||||||
@@ -43,7 +41,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
|||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): RegisteredModelDetails | undefined {
|
public get data(): ModelViewData[] | undefined {
|
||||||
return this.modelDetails?.data;
|
return this.modelDetails?.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,6 +56,13 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
|||||||
* Refreshes the view
|
* Refreshes the view
|
||||||
*/
|
*/
|
||||||
public async refresh(): Promise<void> {
|
public async refresh(): Promise<void> {
|
||||||
|
if (this.modelDetails) {
|
||||||
|
await this.modelDetails.refresh();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async onEnter(): Promise<void> {
|
||||||
|
await this.refresh();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -68,7 +73,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
|||||||
}
|
}
|
||||||
|
|
||||||
public validate(): Promise<boolean> {
|
public validate(): Promise<boolean> {
|
||||||
if (this.data && this.data.title) {
|
if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.title)) {
|
||||||
return Promise.resolve(true);
|
return Promise.resolve(true);
|
||||||
} else {
|
} else {
|
||||||
this.showErrorMessage(constants.modelNameRequiredError);
|
this.showErrorMessage(constants.modelNameRequiredError);
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ import { azureResource } from '../../typings/azure-resource';
|
|||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
|
import { RegisteredModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces';
|
||||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||||
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
||||||
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
|
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
|
||||||
import {
|
import {
|
||||||
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||||
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
|
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
|
||||||
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName
|
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData
|
||||||
} from './modelViewBase';
|
} from './modelViewBase';
|
||||||
import { ControllerBase } from '../controllerBase';
|
import { ControllerBase } from '../controllerBase';
|
||||||
import { RegisterModelWizard } from './registerModels/registerModelWizard';
|
import { RegisterModelWizard } from './registerModels/registerModelWizard';
|
||||||
@@ -122,17 +122,17 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService);
|
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService);
|
||||||
});
|
});
|
||||||
view.on(RegisterLocalModelEventName, async (arg) => {
|
view.on(RegisterLocalModelEventName, async (arg) => {
|
||||||
let registerArgs = <RegisterLocalModelEventArgs>arg;
|
let models = <ModelViewData[]>arg;
|
||||||
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, registerArgs.filePath, registerArgs.details);
|
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models);
|
||||||
view.refresh();
|
view.refresh();
|
||||||
});
|
});
|
||||||
view.on(RegisterModelEventName, async () => {
|
view.on(RegisterModelEventName, async () => {
|
||||||
await this.executeAction(view, RegisterModelEventName, this.registerModel, view, this, this._apiWrapper, this._root);
|
await this.executeAction(view, RegisterModelEventName, this.registerModel, view, this, this._apiWrapper, this._root);
|
||||||
});
|
});
|
||||||
view.on(RegisterAzureModelEventName, async (arg) => {
|
view.on(RegisterAzureModelEventName, async (arg) => {
|
||||||
let registerArgs = <RegisterAzureModelEventArgs>arg;
|
let models = <ModelViewData[]>arg;
|
||||||
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
||||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
|
models);
|
||||||
});
|
});
|
||||||
view.on(DownloadAzureModelEventName, async (arg) => {
|
view.on(DownloadAzureModelEventName, async (arg) => {
|
||||||
let registerArgs = <AzureModelResource>arg;
|
let registerArgs = <AzureModelResource>arg;
|
||||||
@@ -161,7 +161,8 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
|
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
|
||||||
model);
|
model);
|
||||||
});
|
});
|
||||||
view.on(SourceModelSelectedEventName, () => {
|
view.on(SourceModelSelectedEventName, (arg) => {
|
||||||
|
view.modelSourceType = <ModelSourceType>arg;
|
||||||
view.refresh();
|
view.refresh();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -217,35 +218,46 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
||||||
}
|
}
|
||||||
|
|
||||||
private async registerLocalModel(service: DeployedModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
|
private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise<void> {
|
||||||
if (filePath) {
|
if (models) {
|
||||||
await service.deployLocalModel(filePath, details);
|
await Promise.all(models.map(async (model) => {
|
||||||
|
const localModel = <string>model.modelData;
|
||||||
|
if (localModel) {
|
||||||
|
await service.deployLocalModel(localModel, model.modelDetails);
|
||||||
|
}
|
||||||
|
}));
|
||||||
} else {
|
} else {
|
||||||
throw Error(constants.invalidModelToRegisterError);
|
throw Error(constants.invalidModelToRegisterError);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async registerAzureModel(
|
private async registerAzureModel(
|
||||||
azureService: AzureModelRegistryService,
|
azureService: AzureModelRegistryService,
|
||||||
service: DeployedModelService,
|
service: DeployedModelService,
|
||||||
account: azdata.Account | undefined,
|
models: ModelViewData[] | undefined): Promise<void> {
|
||||||
subscription: azureResource.AzureResourceSubscription | undefined,
|
if (!models) {
|
||||||
resourceGroup: azureResource.AzureResource | undefined,
|
|
||||||
workspace: Workspace | undefined,
|
|
||||||
model: WorkspaceModel | undefined,
|
|
||||||
details: RegisteredModelDetails | undefined): Promise<void> {
|
|
||||||
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
|
|
||||||
throw Error(constants.invalidAzureResourceError);
|
throw Error(constants.invalidAzureResourceError);
|
||||||
}
|
}
|
||||||
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
|
|
||||||
if (filePath) {
|
|
||||||
|
|
||||||
await service.deployLocalModel(filePath, details);
|
await Promise.all(models.map(async (model) => {
|
||||||
await fs.promises.unlink(filePath);
|
const azureModel = <AzureModelResource>model.modelData;
|
||||||
} else {
|
if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) {
|
||||||
throw Error(constants.invalidModelToRegisterError);
|
let filePath: string | undefined;
|
||||||
}
|
try {
|
||||||
|
const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group,
|
||||||
|
azureModel.workspace, azureModel.model);
|
||||||
|
if (filePath) {
|
||||||
|
await service.deployLocalModel(filePath, model.modelDetails);
|
||||||
|
} else {
|
||||||
|
throw Error(constants.invalidModelToRegisterError);
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
if (filePath) {
|
||||||
|
await fs.promises.unlink(filePath);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
||||||
|
|||||||
@@ -4,14 +4,11 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import { ModelViewBase } from './modelViewBase';
|
import { ModelViewBase, ModelSourceType } from './modelViewBase';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
import { IPageView, IDataComponent } from '../interfaces';
|
import { IPageView, IDataComponent } from '../interfaces';
|
||||||
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
|
import { ModelSourcesComponent } from './modelSourcesComponent';
|
||||||
import { LocalModelsComponent } from './localModelsComponent';
|
|
||||||
import { AzureModelsComponent } from './azureModelsComponent';
|
|
||||||
import { CurrentModelsTable } from './registerModels/currentModelsTable';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to pick model source
|
* View to pick model source
|
||||||
@@ -21,9 +18,6 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
private _formBuilder: azdata.FormBuilder | undefined;
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
public modelResources: ModelSourcesComponent | undefined;
|
public modelResources: ModelSourcesComponent | undefined;
|
||||||
public localModelsComponent: LocalModelsComponent | undefined;
|
|
||||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
|
||||||
public registeredModelsComponent: CurrentModelsTable | undefined;
|
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
@@ -38,14 +32,7 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
this._formBuilder = modelBuilder.formContainer();
|
this._formBuilder = modelBuilder.formContainer();
|
||||||
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
|
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
|
||||||
this.modelResources.registerComponent(modelBuilder);
|
this.modelResources.registerComponent(modelBuilder);
|
||||||
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
|
|
||||||
this.localModelsComponent.registerComponent(modelBuilder);
|
|
||||||
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
|
|
||||||
this.azureModelsComponent.registerComponent(modelBuilder);
|
|
||||||
this.modelResources.addComponents(this._formBuilder);
|
this.modelResources.addComponents(this._formBuilder);
|
||||||
this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this);
|
|
||||||
this.registeredModelsComponent.registerComponent(modelBuilder);
|
|
||||||
this.refresh();
|
|
||||||
this._form = this._formBuilder.component();
|
this._form = this._formBuilder.component();
|
||||||
return this._form;
|
return this._form;
|
||||||
}
|
}
|
||||||
@@ -68,33 +55,6 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
* Refreshes the view
|
* Refreshes the view
|
||||||
*/
|
*/
|
||||||
public async refresh(): Promise<void> {
|
public async refresh(): Promise<void> {
|
||||||
if (this._formBuilder) {
|
|
||||||
if (this.modelResources && this.modelResources.data === ModelSourceType.Local) {
|
|
||||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
|
||||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
|
||||||
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
|
||||||
this.localModelsComponent.addComponents(this._formBuilder);
|
|
||||||
await this.localModelsComponent.refresh();
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) {
|
|
||||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
|
||||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
|
||||||
this.azureModelsComponent.addComponents(this._formBuilder);
|
|
||||||
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
|
||||||
await this.azureModelsComponent.refresh();
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels) {
|
|
||||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
|
||||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
|
||||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
|
||||||
this.registeredModelsComponent.addComponents(this._formBuilder);
|
|
||||||
await this.registeredModelsComponent.refresh();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -104,30 +64,6 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
return constants.modelSourcePageTitle;
|
return constants.modelSourcePageTitle;
|
||||||
}
|
}
|
||||||
|
|
||||||
public validate(): Promise<boolean> {
|
|
||||||
let validated = false;
|
|
||||||
if (this.modelResources && this.modelResources.data === ModelSourceType.Local && this.localModelsComponent) {
|
|
||||||
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
|
|
||||||
|
|
||||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure && this.azureModelsComponent) {
|
|
||||||
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.model !== undefined;
|
|
||||||
|
|
||||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
|
|
||||||
validated = this.registeredModelsComponent.data !== undefined;
|
|
||||||
}
|
|
||||||
if (!validated) {
|
|
||||||
this.showErrorMessage(constants.invalidModelToSelectError);
|
|
||||||
}
|
|
||||||
return Promise.resolve(validated);
|
|
||||||
}
|
|
||||||
|
|
||||||
public async disposePage(): Promise<void> {
|
public async disposePage(): Promise<void> {
|
||||||
if (this.azureModelsComponent) {
|
|
||||||
await this.azureModelsComponent.disposeComponent();
|
|
||||||
|
|
||||||
}
|
|
||||||
if (this.registeredModelsComponent) {
|
|
||||||
await this.registeredModelsComponent.disposeComponent();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,16 +4,11 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import { ModelViewBase, SourceModelSelectedEventName } from './modelViewBase';
|
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
import { IDataComponent } from '../interfaces';
|
import { IDataComponent } from '../interfaces';
|
||||||
|
|
||||||
export enum ModelSourceType {
|
|
||||||
Local,
|
|
||||||
Azure,
|
|
||||||
RegisteredModels
|
|
||||||
}
|
|
||||||
/**
|
/**
|
||||||
* View to pick model source
|
* View to pick model source
|
||||||
*/
|
*/
|
||||||
@@ -21,9 +16,9 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
|
|
||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
private _flexContainer: azdata.FlexContainer | undefined;
|
private _flexContainer: azdata.FlexContainer | undefined;
|
||||||
private _amlModel: azdata.RadioButtonComponent | undefined;
|
private _amlModel: azdata.CardComponent | undefined;
|
||||||
private _localModel: azdata.RadioButtonComponent | undefined;
|
private _localModel: azdata.CardComponent | undefined;
|
||||||
private _registeredModels: azdata.RadioButtonComponent | undefined;
|
private _registeredModels: azdata.CardComponent | undefined;
|
||||||
private _sourceType: ModelSourceType = ModelSourceType.Local;
|
private _sourceType: ModelSourceType = ModelSourceType.Local;
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||||
@@ -35,45 +30,61 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
* @param modelBuilder Register components
|
* @param modelBuilder Register components
|
||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
this._localModel = modelBuilder.radioButton()
|
|
||||||
|
this._localModel = modelBuilder.card()
|
||||||
.withProperties({
|
.withProperties({
|
||||||
value: 'local',
|
value: 'local',
|
||||||
name: 'modelLocation',
|
name: 'modelLocation',
|
||||||
label: constants.localModelSource,
|
label: constants.localModelSource,
|
||||||
checked: this._options[0] === ModelSourceType.Local
|
selected: this._options[0] === ModelSourceType.Local,
|
||||||
|
cardType: azdata.CardType.VerticalButton,
|
||||||
|
width: 50
|
||||||
}).component();
|
}).component();
|
||||||
|
this._amlModel = modelBuilder.card()
|
||||||
|
|
||||||
this._amlModel = modelBuilder.radioButton()
|
|
||||||
.withProperties({
|
.withProperties({
|
||||||
value: 'aml',
|
value: 'aml',
|
||||||
name: 'modelLocation',
|
name: 'modelLocation',
|
||||||
label: constants.azureModelSource,
|
label: constants.azureModelSource,
|
||||||
checked: this._options[0] === ModelSourceType.Azure
|
selected: this._options[0] === ModelSourceType.Azure,
|
||||||
|
cardType: azdata.CardType.VerticalButton,
|
||||||
|
width: 50
|
||||||
}).component();
|
}).component();
|
||||||
|
|
||||||
this._registeredModels = modelBuilder.radioButton()
|
this._registeredModels = modelBuilder.card()
|
||||||
.withProperties({
|
.withProperties({
|
||||||
value: 'registered',
|
value: 'registered',
|
||||||
name: 'modelLocation',
|
name: 'modelLocation',
|
||||||
label: constants.registeredModelsSource,
|
label: constants.registeredModelsSource,
|
||||||
checked: this._options[0] === ModelSourceType.RegisteredModels
|
selected: this._options[0] === ModelSourceType.RegisteredModels,
|
||||||
|
cardType: azdata.CardType.VerticalButton,
|
||||||
|
width: 50
|
||||||
}).component();
|
}).component();
|
||||||
|
|
||||||
this._localModel.onDidClick(() => {
|
this._localModel.onCardSelectedChanged(() => {
|
||||||
this._sourceType = ModelSourceType.Local;
|
this._sourceType = ModelSourceType.Local;
|
||||||
this.sendRequest(SourceModelSelectedEventName);
|
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||||
|
if (this._amlModel && this._registeredModels) {
|
||||||
|
this._amlModel.selected = false;
|
||||||
|
this._registeredModels.selected = false;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
this._amlModel.onDidClick(() => {
|
this._amlModel.onCardSelectedChanged(() => {
|
||||||
this._sourceType = ModelSourceType.Azure;
|
this._sourceType = ModelSourceType.Azure;
|
||||||
this.sendRequest(SourceModelSelectedEventName);
|
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||||
|
if (this._localModel && this._registeredModels) {
|
||||||
|
this._localModel.selected = false;
|
||||||
|
this._registeredModels.selected = false;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
this._registeredModels.onDidClick(() => {
|
this._registeredModels.onCardSelectedChanged(() => {
|
||||||
this._sourceType = ModelSourceType.RegisteredModels;
|
this._sourceType = ModelSourceType.RegisteredModels;
|
||||||
this.sendRequest(SourceModelSelectedEventName);
|
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||||
|
if (this._localModel && this._amlModel) {
|
||||||
|
this._localModel.selected = false;
|
||||||
|
this._amlModel.selected = false;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
let components: azdata.RadioButtonComponent[] = [];
|
let components: azdata.Component[] = [];
|
||||||
|
|
||||||
this._options.forEach(option => {
|
this._options.forEach(option => {
|
||||||
switch (option) {
|
switch (option) {
|
||||||
@@ -95,10 +106,11 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
this._sourceType = this._options[0];
|
this._sourceType = this._options[0];
|
||||||
|
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||||
|
|
||||||
this._flexContainer = modelBuilder.flexContainer()
|
this._flexContainer = modelBuilder.flexContainer()
|
||||||
.withLayout({
|
.withLayout({
|
||||||
flexFlow: 'column',
|
flexFlow: 'row',
|
||||||
justifyContent: 'space-between'
|
justifyContent: 'space-between'
|
||||||
}).withItems(components).component();
|
}).withItems(components).component();
|
||||||
|
|
||||||
@@ -112,13 +124,13 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
|
|
||||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._flexContainer) {
|
if (this._flexContainer) {
|
||||||
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer });
|
formBuilder.addFormItem({ title: '', component: this._flexContainer });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._flexContainer) {
|
if (this._flexContainer) {
|
||||||
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer });
|
formBuilder.removeFormItem({ title: '', component: this._flexContainer });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/
|
|||||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
||||||
|
|
||||||
|
|
||||||
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
|
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -20,17 +21,22 @@ export interface RegisterModelEventArgs extends AzureWorkspaceResource {
|
|||||||
details?: RegisteredModelDetails
|
details?: RegisteredModelDetails
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs {
|
|
||||||
model?: WorkspaceModel;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface PredictModelEventArgs extends PredictParameters {
|
export interface PredictModelEventArgs extends PredictParameters {
|
||||||
model?: RegisteredModel;
|
model?: RegisteredModel;
|
||||||
filePath?: string;
|
filePath?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
|
|
||||||
filePath?: string;
|
export enum ModelSourceType {
|
||||||
|
Local,
|
||||||
|
Azure,
|
||||||
|
RegisteredModels
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ModelViewData {
|
||||||
|
modelFile?: string;
|
||||||
|
modelData: AzureModelResource | string | RegisteredModel;
|
||||||
|
modelDetails?: RegisteredModelDetails;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Event names
|
// Event names
|
||||||
@@ -58,6 +64,9 @@ export const LoadModelParametersEventName = 'loadModelParameters';
|
|||||||
*/
|
*/
|
||||||
export abstract class ModelViewBase extends ViewBase {
|
export abstract class ModelViewBase extends ViewBase {
|
||||||
|
|
||||||
|
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
|
||||||
|
private _modelsViewData: ModelViewData[] = [];
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
|
||||||
super(apiWrapper, root, parent);
|
super(apiWrapper, root, parent);
|
||||||
}
|
}
|
||||||
@@ -147,12 +156,8 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
* registers local model
|
* registers local model
|
||||||
* @param localFilePath local file path
|
* @param localFilePath local file path
|
||||||
*/
|
*/
|
||||||
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
|
public async registerLocalModel(models: ModelViewData[]): Promise<void> {
|
||||||
const args: RegisterLocalModelEventArgs = {
|
return await this.sendDataRequest(RegisterLocalModelEventName, models);
|
||||||
filePath: localFilePath,
|
|
||||||
details: details
|
|
||||||
};
|
|
||||||
return await this.sendDataRequest(RegisterLocalModelEventName, args);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -182,11 +187,8 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
* registers azure model
|
* registers azure model
|
||||||
* @param args azure resource
|
* @param args azure resource
|
||||||
*/
|
*/
|
||||||
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
|
public async registerAzureModel(models: ModelViewData[]): Promise<void> {
|
||||||
const args: RegisterAzureModelEventArgs = Object.assign({}, resource, {
|
return await this.sendDataRequest(RegisterAzureModelEventName, models);
|
||||||
details: details
|
|
||||||
});
|
|
||||||
return await this.sendDataRequest(RegisterAzureModelEventName, args);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -215,6 +217,50 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
return await this.sendDataRequest(ListGroupsEventName, args);
|
return await this.sendDataRequest(ListGroupsEventName, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets model source type
|
||||||
|
*/
|
||||||
|
public set modelSourceType(value: ModelSourceType) {
|
||||||
|
if (this.parent) {
|
||||||
|
this.parent.modelSourceType = value;
|
||||||
|
} else {
|
||||||
|
this._modelSourceType = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns model source type
|
||||||
|
*/
|
||||||
|
public get modelSourceType(): ModelSourceType {
|
||||||
|
if (this.parent) {
|
||||||
|
return this.parent.modelSourceType;
|
||||||
|
} else {
|
||||||
|
return this._modelSourceType;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets model source type
|
||||||
|
*/
|
||||||
|
public set modelsViewData(value: ModelViewData[]) {
|
||||||
|
if (this.parent) {
|
||||||
|
this.parent.modelsViewData = value;
|
||||||
|
} else {
|
||||||
|
this._modelsViewData = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns model source type
|
||||||
|
*/
|
||||||
|
public get modelsViewData(): ModelViewData[] {
|
||||||
|
if (this.parent) {
|
||||||
|
return this.parent.modelsViewData;
|
||||||
|
} else {
|
||||||
|
return this._modelsViewData;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* lists azure workspaces
|
* lists azure workspaces
|
||||||
* @param account azure account
|
* @param account azure account
|
||||||
|
|||||||
@@ -4,9 +4,9 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import { ModelViewBase } from '../modelViewBase';
|
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
|
||||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
|
import { ModelSourcesComponent } from '../modelSourcesComponent';
|
||||||
import { LocalModelsComponent } from '../localModelsComponent';
|
import { LocalModelsComponent } from '../localModelsComponent';
|
||||||
import { AzureModelsComponent } from '../azureModelsComponent';
|
import { AzureModelsComponent } from '../azureModelsComponent';
|
||||||
import * as constants from '../../../common/constants';
|
import * as constants from '../../../common/constants';
|
||||||
@@ -15,6 +15,7 @@ import { ModelSourcePage } from '../modelSourcePage';
|
|||||||
import { ColumnsSelectionPage } from './columnsSelectionPage';
|
import { ColumnsSelectionPage } from './columnsSelectionPage';
|
||||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
import { ModelArtifact } from './modelArtifact';
|
import { ModelArtifact } from './modelArtifact';
|
||||||
|
import { ModelBrowsePage } from '../modelBrowsePage';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wizard to register a model
|
* Wizard to register a model
|
||||||
@@ -23,6 +24,7 @@ export class PredictWizard extends ModelViewBase {
|
|||||||
|
|
||||||
public modelSourcePage: ModelSourcePage | undefined;
|
public modelSourcePage: ModelSourcePage | undefined;
|
||||||
public columnsSelectionPage: ColumnsSelectionPage | undefined;
|
public columnsSelectionPage: ColumnsSelectionPage | undefined;
|
||||||
|
public modelBrowsePage: ModelBrowsePage | undefined;
|
||||||
public wizardView: WizardView | undefined;
|
public wizardView: WizardView | undefined;
|
||||||
private _parentView: ModelViewBase | undefined;
|
private _parentView: ModelViewBase | undefined;
|
||||||
|
|
||||||
@@ -40,10 +42,12 @@ export class PredictWizard extends ModelViewBase {
|
|||||||
public async open(): Promise<void> {
|
public async open(): Promise<void> {
|
||||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
|
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
|
||||||
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
|
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
|
||||||
|
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false);
|
||||||
this.wizardView = new WizardView(this._apiWrapper);
|
this.wizardView = new WizardView(this._apiWrapper);
|
||||||
|
|
||||||
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
|
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
|
||||||
[this.modelSourcePage,
|
[this.modelSourcePage,
|
||||||
|
this.modelBrowsePage,
|
||||||
this.columnsSelectionPage]);
|
this.columnsSelectionPage]);
|
||||||
|
|
||||||
this.mainViewPanel = wizard;
|
this.mainViewPanel = wizard;
|
||||||
@@ -57,7 +61,10 @@ export class PredictWizard extends ModelViewBase {
|
|||||||
await this.onClose();
|
await this.onClose();
|
||||||
});
|
});
|
||||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||||
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
let validated: boolean = true;
|
||||||
|
if (pageInfo.newPage > pageInfo.lastPage) {
|
||||||
|
validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||||
|
}
|
||||||
if (validated) {
|
if (validated) {
|
||||||
if (pageInfo.newPage === undefined) {
|
if (pageInfo.newPage === undefined) {
|
||||||
this.onLoading();
|
this.onLoading();
|
||||||
@@ -96,20 +103,20 @@ export class PredictWizard extends ModelViewBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public get localModelsComponent(): LocalModelsComponent | undefined {
|
public get localModelsComponent(): LocalModelsComponent | undefined {
|
||||||
return this.modelSourcePage?.localModelsComponent;
|
return this.modelBrowsePage?.localModelsComponent;
|
||||||
}
|
}
|
||||||
|
|
||||||
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
||||||
return this.modelSourcePage?.azureModelsComponent;
|
return this.modelBrowsePage?.azureModelsComponent;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getModelFileName(): Promise<ModelArtifact | undefined> {
|
public async getModelFileName(): Promise<ModelArtifact | undefined> {
|
||||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||||
return new ModelArtifact(this.localModelsComponent.data, false);
|
return new ModelArtifact(this.localModelsComponent.data[0], false);
|
||||||
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
||||||
return await this.azureModelsComponent.getDownloadedModel();
|
return await this.azureModelsComponent.getDownloadedModel();
|
||||||
} else if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
|
} else if (this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) {
|
||||||
return await this.modelSourcePage.registeredModelsComponent.getDownloadedModel();
|
return await this.modelBrowsePage.registeredModelsComponent.getDownloadedModel();
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
@@ -118,8 +125,10 @@ export class PredictWizard extends ModelViewBase {
|
|||||||
try {
|
try {
|
||||||
let modelFilePath: string | undefined;
|
let modelFilePath: string | undefined;
|
||||||
let registeredModel: RegisteredModel | undefined = undefined;
|
let registeredModel: RegisteredModel | undefined = undefined;
|
||||||
if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
|
if (this.modelResources && this.modelResources.data && this.modelResources.data === ModelSourceType.RegisteredModels
|
||||||
registeredModel = this.modelSourcePage?.registeredModelsComponent?.data;
|
&& this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) {
|
||||||
|
const data = this.modelBrowsePage?.registeredModelsComponent?.data;
|
||||||
|
registeredModel = data && data.length > 0 ? data[0] : undefined;
|
||||||
} else {
|
} else {
|
||||||
const artifact = await this.getModelFileName();
|
const artifact = await this.getModelFileName();
|
||||||
modelFilePath = artifact?.filePath;
|
modelFilePath = artifact?.filePath;
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
|||||||
* @param modelBuilder register the components
|
* @param modelBuilder register the components
|
||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
this._dataTable = new CurrentModelsTable(this._apiWrapper, this);
|
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, false);
|
||||||
this._dataTable.registerComponent(modelBuilder);
|
this._dataTable.registerComponent(modelBuilder);
|
||||||
this._tableComponent = this._dataTable.component;
|
this._tableComponent = this._dataTable.component;
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ import { ModelArtifact } from '../prediction/modelArtifact';
|
|||||||
/**
|
/**
|
||||||
* View to render registered models table
|
* View to render registered models table
|
||||||
*/
|
*/
|
||||||
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<RegisteredModel> {
|
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<RegisteredModel[]> {
|
||||||
|
|
||||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||||
private _modelBuilder: azdata.ModelBuilder | undefined;
|
private _modelBuilder: azdata.ModelBuilder | undefined;
|
||||||
private _selectedModel: any;
|
private _selectedModel: RegisteredModel[] = [];
|
||||||
private _loader: azdata.LoadingComponent | undefined;
|
private _loader: azdata.LoadingComponent | undefined;
|
||||||
private _downloadedFile: ModelArtifact | undefined;
|
private _downloadedFile: ModelArtifact | undefined;
|
||||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||||
@@ -28,7 +28,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
/**
|
/**
|
||||||
* Creates new view
|
* Creates new view
|
||||||
*/
|
*/
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,17 +161,45 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
|
|
||||||
private createTableRow(model: RegisteredModel): any[] {
|
private createTableRow(model: RegisteredModel): any[] {
|
||||||
if (this._modelBuilder) {
|
if (this._modelBuilder) {
|
||||||
let selectModelButton = this._modelBuilder.radioButton().withProperties({
|
let selectModelButton: azdata.Component;
|
||||||
name: 'amlModel',
|
let onSelectItem = (checked: boolean) => {
|
||||||
value: model.id,
|
if (!this._multiSelect) {
|
||||||
width: 15,
|
this._selectedModel = [];
|
||||||
height: 15,
|
}
|
||||||
checked: false
|
const foundItem = this._selectedModel.find(x => x === model);
|
||||||
}).component();
|
if (checked && !foundItem) {
|
||||||
selectModelButton.onDidClick(async () => {
|
this._selectedModel.push(model);
|
||||||
this._selectedModel = model;
|
} else if (foundItem) {
|
||||||
await this.onModelSelected();
|
this._selectedModel = this._selectedModel.filter(x => x !== model);
|
||||||
});
|
}
|
||||||
|
this.onModelSelected();
|
||||||
|
};
|
||||||
|
if (this._multiSelect) {
|
||||||
|
const checkbox = this._modelBuilder.checkBox().withProperties({
|
||||||
|
name: 'amlModel',
|
||||||
|
value: model.id,
|
||||||
|
width: 15,
|
||||||
|
height: 15,
|
||||||
|
checked: false
|
||||||
|
}).component();
|
||||||
|
checkbox.onChanged(() => {
|
||||||
|
onSelectItem(checkbox.checked || false);
|
||||||
|
});
|
||||||
|
selectModelButton = checkbox;
|
||||||
|
} else {
|
||||||
|
const radioButton = this._modelBuilder.radioButton().withProperties({
|
||||||
|
name: 'amlModel',
|
||||||
|
value: model.id,
|
||||||
|
width: 15,
|
||||||
|
height: 15,
|
||||||
|
checked: false
|
||||||
|
}).component();
|
||||||
|
radioButton.onDidClick(() => {
|
||||||
|
onSelectItem(radioButton.checked || false);
|
||||||
|
});
|
||||||
|
selectModelButton = radioButton;
|
||||||
|
}
|
||||||
|
|
||||||
return [model.artifactName, model.title, model.created, selectModelButton];
|
return [model.artifactName, model.title, model.created, selectModelButton];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,13 +217,13 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): RegisteredModel | undefined {
|
public get data(): RegisteredModel[] | undefined {
|
||||||
return this._selectedModel;
|
return this._selectedModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getDownloadedModel(): Promise<ModelArtifact> {
|
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
|
||||||
if (!this._downloadedFile) {
|
if (!this._downloadedFile && this.data && this.data.length > 0) {
|
||||||
this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data));
|
this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data[0]));
|
||||||
}
|
}
|
||||||
return this._downloadedFile;
|
return this._downloadedFile;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,15 +4,16 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import { ModelViewBase } from '../modelViewBase';
|
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
|
||||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
|
import { ModelSourcesComponent } from '../modelSourcesComponent';
|
||||||
import { LocalModelsComponent } from '../localModelsComponent';
|
import { LocalModelsComponent } from '../localModelsComponent';
|
||||||
import { AzureModelsComponent } from '../azureModelsComponent';
|
import { AzureModelsComponent } from '../azureModelsComponent';
|
||||||
import * as constants from '../../../common/constants';
|
import * as constants from '../../../common/constants';
|
||||||
import { WizardView } from '../../wizardView';
|
import { WizardView } from '../../wizardView';
|
||||||
import { ModelSourcePage } from '../modelSourcePage';
|
import { ModelSourcePage } from '../modelSourcePage';
|
||||||
import { ModelDetailsPage } from '../modelDetailsPage';
|
import { ModelDetailsPage } from '../modelDetailsPage';
|
||||||
|
import { ModelBrowsePage } from '../modelBrowsePage';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wizard to register a model
|
* Wizard to register a model
|
||||||
@@ -20,6 +21,7 @@ import { ModelDetailsPage } from '../modelDetailsPage';
|
|||||||
export class RegisterModelWizard extends ModelViewBase {
|
export class RegisterModelWizard extends ModelViewBase {
|
||||||
|
|
||||||
public modelSourcePage: ModelSourcePage | undefined;
|
public modelSourcePage: ModelSourcePage | undefined;
|
||||||
|
public modelBrowsePage: ModelBrowsePage | undefined;
|
||||||
public modelDetailsPage: ModelDetailsPage | undefined;
|
public modelDetailsPage: ModelDetailsPage | undefined;
|
||||||
public wizardView: WizardView | undefined;
|
public wizardView: WizardView | undefined;
|
||||||
private _parentView: ModelViewBase | undefined;
|
private _parentView: ModelViewBase | undefined;
|
||||||
@@ -38,16 +40,20 @@ export class RegisterModelWizard extends ModelViewBase {
|
|||||||
public async open(): Promise<void> {
|
public async open(): Promise<void> {
|
||||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
|
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
|
||||||
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
|
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
|
||||||
|
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this);
|
||||||
this.wizardView = new WizardView(this._apiWrapper);
|
this.wizardView = new WizardView(this._apiWrapper);
|
||||||
|
|
||||||
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelDetailsPage]);
|
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]);
|
||||||
|
|
||||||
this.mainViewPanel = wizard;
|
this.mainViewPanel = wizard;
|
||||||
wizard.doneButton.label = constants.azureRegisterModel;
|
wizard.doneButton.label = constants.azureRegisterModel;
|
||||||
wizard.generateScriptButton.hidden = true;
|
wizard.generateScriptButton.hidden = true;
|
||||||
wizard.displayPageTitles = true;
|
wizard.displayPageTitles = true;
|
||||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||||
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
let validated: boolean = true;
|
||||||
|
if (pageInfo.newPage > pageInfo.lastPage) {
|
||||||
|
validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||||
|
}
|
||||||
if (validated && pageInfo.newPage === undefined) {
|
if (validated && pageInfo.newPage === undefined) {
|
||||||
wizard.cancelButton.enabled = false;
|
wizard.cancelButton.enabled = false;
|
||||||
wizard.backButton.enabled = false;
|
wizard.backButton.enabled = false;
|
||||||
@@ -71,19 +77,19 @@ export class RegisterModelWizard extends ModelViewBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public get localModelsComponent(): LocalModelsComponent | undefined {
|
public get localModelsComponent(): LocalModelsComponent | undefined {
|
||||||
return this.modelSourcePage?.localModelsComponent;
|
return this.modelBrowsePage?.localModelsComponent;
|
||||||
}
|
}
|
||||||
|
|
||||||
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
||||||
return this.modelSourcePage?.azureModelsComponent;
|
return this.modelBrowsePage?.azureModelsComponent;
|
||||||
}
|
}
|
||||||
|
|
||||||
private async registerModel(): Promise<boolean> {
|
private async registerModel(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||||
await this.registerLocalModel(this.localModelsComponent.data, this.modelDetailsPage?.data);
|
await this.registerLocalModel(this.modelsViewData);
|
||||||
} else {
|
} else {
|
||||||
await this.registerAzureModel(this.azureModelsComponent?.data, this.modelDetailsPage?.data);
|
await this.registerAzureModel(this.modelsViewData);
|
||||||
}
|
}
|
||||||
this.showInfoMessage(constants.modelRegisteredSuccessfully);
|
this.showInfoMessage(constants.modelRegisteredSuccessfully);
|
||||||
return true;
|
return true;
|
||||||
@@ -93,14 +99,10 @@ export class RegisterModelWizard extends ModelViewBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private loadPages(): void {
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Refresh the pages
|
* Refresh the pages
|
||||||
*/
|
*/
|
||||||
public async refresh(): Promise<void> {
|
public async refresh(): Promise<void> {
|
||||||
this.loadPages();
|
|
||||||
await this.wizardView?.refresh();
|
await this.wizardView?.refresh();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user