diff --git a/extensions/machine-learning-services/src/common/constants.ts b/extensions/machine-learning-services/src/common/constants.ts index 565a111475..175eaa80d0 100644 --- a/extensions/machine-learning-services/src/common/constants.ts +++ b/extensions/machine-learning-services/src/common/constants.ts @@ -117,6 +117,7 @@ export const extLangUpdateFailedError = localize('extLang.updateFailedError', "F export const modelArtifactName = localize('models.artifactName', "Artifact Name"); export const modelName = localize('models.name', "Name"); +export const modelFileName = localize('models.fileName', "File"); export const modelDescription = localize('models.description', "Description"); export const modelCreated = localize('models.created', "Date Created"); export const modelVersion = localize('models.version', "Version"); @@ -139,9 +140,9 @@ export const azureModels = localize('models.azureModels', "Models"); export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models"); export const localModelsTitle = localize('models.localModelsTitle', "Local models"); export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location"); -export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Enter model source details"); +export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?"); export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input"); -export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details"); +export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details"); export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file"); export const currentModelsTitle = localize('models.currentModelsTitle', "Models"); export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy"); @@ -151,9 +152,9 @@ export const deployModelTitle = localize('models.deployModelTitle', "Deploy mode export const makePredictionTitle = localize('models.makePredictionTitle', "Make prediction"); export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully"); export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register"); -export const localModelSource = localize('models.localModelSource', "Upload file"); -export const azureModelSource = localize('models.azureModelSource', "Import from AzureML registry"); -export const registeredModelsSource = localize('models.registeredModelsSource', "Select managed models"); +export const localModelSource = localize('models.localModelSource', "File upload"); +export const azureModelSource = localize('models.azureModelSource', "Azure Machine Learning"); +export const registeredModelsSource = localize('models.registeredModelsSource', "Imported models"); export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure"); export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource"); export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register"); @@ -161,7 +162,8 @@ export const invalidModelToPredictError = localize('models.invalidModelToPredict export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model"); export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); -export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model"); +export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); } + export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'"); export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported"); diff --git a/extensions/machine-learning-services/src/common/utils.ts b/extensions/machine-learning-services/src/common/utils.ts index 96c8697c38..aac4966ada 100644 --- a/extensions/machine-learning-services/src/common/utils.ts +++ b/extensions/machine-learning-services/src/common/utils.ts @@ -248,3 +248,15 @@ export async function writeFileFromHex(content: string): Promise { await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex')); return tempFilePath; } + +/** + * + * @param filePath Returns file name + */ +export function getFileName(filePath: string) { + if (filePath) { + return filePath.replace(/^.*[\\\/]/, ''); + } else { + return ''; + } +} diff --git a/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts index 7bee455235..b3ac7161db 100644 --- a/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts +++ b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts @@ -86,22 +86,23 @@ export class DeployedModelService { let connection = await this.getCurrentConnection(); if (connection) { let currentModels = await this.getDeployedModels(); - await this._modelClient.deployModel(connection, filePath); - let updatedModels = await this.getDeployedModels(); - if (details && updatedModels.length >= currentModels.length + 1) { - updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0); - const addedModel = updatedModels[updatedModels.length - 1]; - addedModel.title = details.title; - addedModel.description = details.description; - addedModel.version = details.version; - const updatedModel = await this.updateModel(addedModel); - if (!updatedModel) { - throw Error(constants.updateModelFailedError); - } + const content = await utils.readFileInHex(filePath); + const fileName = details?.fileName || utils.getFileName(filePath); + let modelToAdd: RegisteredModel = { + id: 0, + artifactName: fileName, + content: content, + title: details?.title || fileName, + description: details?.description, + version: details?.version + }; + await this._queryRunner.safeRunQuery(connection, this.getInsertModelQuery(connection.databaseName, modelToAdd)); - } else { - throw Error(constants.importModelFailedError); + let updatedModels = await this.getDeployedModels(); + if (updatedModels.length < currentModels.length + 1) { + throw Error(constants.importModelFailedError(details?.title, filePath)); } + } } private loadModelData(row: azdata.DbCellValue[]): RegisteredModel { @@ -115,20 +116,6 @@ export class DeployedModelService { }; } - private async updateModel(model: RegisteredModel): Promise { - let connection = await this.getCurrentConnection(); - let updatedModel: RegisteredModel | undefined = undefined; - if (connection) { - const query = this.getUpdateModelQuery(connection.databaseName, model); - let result = await this._queryRunner.safeRunQuery(connection, query); - if (result?.rows && result.rows.length > 0) { - const row = result.rows[0]; - updatedModel = this.loadModelData(row); - } - } - return updatedModel; - } - private async getCurrentConnection(): Promise { return await this._apiWrapper.getCurrentConnection(); } @@ -190,7 +177,7 @@ export class DeployedModelService { CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}( [artifact_id] [int] IDENTITY(1,1) NOT NULL, [artifact_name] [varchar](256) NOT NULL, - [group_path] [varchar](256) NOT NULL, + [group_path] [varchar](256) NULL, [artifact_content] [varbinary](max) NOT NULL, [artifact_initial_size] [bigint] NULL, [name] [varchar](256) NULL, @@ -207,20 +194,24 @@ export class DeployedModelService { `; } - public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string { + public getInsertModelQuery(currentDatabaseName: string, model: RegisteredModel): string { let updateScript = ` - UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)} - SET - name = '${utils.doubleEscapeSingleQuotes(model.title || '')}', - version = '${utils.doubleEscapeSingleQuotes(model.version || '')}', - description = '${utils.doubleEscapeSingleQuotes(model.description || '')}' - WHERE artifact_id = ${model.id}`; + Insert into ${utils.getRegisteredModelsTowPartsName(this._config)} + (artifact_name, group_path, artifact_content, name, version, description) + values ( + '${utils.doubleEscapeSingleQuotes(model.artifactName || '')}', + 'ADS', + ${utils.doubleEscapeSingleQuotes(model.content || '')}, + '${utils.doubleEscapeSingleQuotes(model.title || '')}', + '${utils.doubleEscapeSingleQuotes(model.version || '')}', + '${utils.doubleEscapeSingleQuotes(model.description || '')}') + `; return ` ${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)} SELECT artifact_id, artifact_name, name, description, version, created FROM ${utils.getRegisteredModelsThreePartsName(this._config)} - WHERE artifact_id = ${model.id}; + WHERE artifact_id = SCOPE_IDENTITY(); `; } diff --git a/extensions/machine-learning-services/src/modelManagement/interfaces.ts b/extensions/machine-learning-services/src/modelManagement/interfaces.ts index f827bffc34..39a697fcfe 100644 --- a/extensions/machine-learning-services/src/modelManagement/interfaces.ts +++ b/extensions/machine-learning-services/src/modelManagement/interfaces.ts @@ -51,6 +51,7 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & { export interface RegisteredModel extends RegisteredModelDetails { id: number; artifactName: string; + content?: string; } export interface ModelParameter { @@ -71,6 +72,7 @@ export interface RegisteredModelDetails { created?: string; version?: string; description?: string; + fileName?: string; } /** @@ -230,3 +232,4 @@ export interface ArtifactAPIGetArtifactContentInformation2OptionalParams extends */ accountName?: string; } + diff --git a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts index f324ed11b1..6d90c665df 100644 --- a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts +++ b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts @@ -248,9 +248,10 @@ describe('DeployedModelService', () => { testContext.config.object, testContext.queryRunner.object, testContext.modelClient.object); - testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => { + + testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => { deployed = true; - return Promise.resolve(); + return Promise.resolve(result); }); testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result); @@ -259,7 +260,15 @@ describe('DeployedModelService', () => { testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table'); testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); - await should(service.deployLocalModel('', model)).resolved(); + let tempFilePath: string = ''; + try { + tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`); + await fs.promises.writeFile(tempFilePath, 'test'); + await should(service.deployLocalModel(tempFilePath, model)).resolved(); + } + finally { + await utils.deleteFile(tempFilePath); + } }); it('getConfigureQuery should escape db name', async function (): Promise { @@ -306,7 +315,7 @@ describe('DeployedModelService', () => { CREATE TABLE [dbo].[ta[[b]]le]( [artifact_id] [int] IDENTITY(1,1) NOT NULL, [artifact_name] [varchar](256) NOT NULL, - [group_path] [varchar](256) NOT NULL, + [group_path] [varchar](256) NULL, [artifact_content] [varbinary](max) NOT NULL, [artifact_initial_size] [bigint] NULL, [name] [varchar](256) NULL, @@ -345,7 +354,7 @@ describe('DeployedModelService', () => { should.deepEqual(expected, actual); }); - it('getUpdateModelQuery should escape db name', async function (): Promise { + it('getInsertModelQuery should escape db name', async function (): Promise { const testContext = createContext(); const dbName = 'curre[n]tDb'; const model: RegisteredModel = @@ -367,16 +376,17 @@ describe('DeployedModelService', () => { testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); const expected = ` - UPDATE [dbo].[ta[[b]]le] - SET - name = 'title1', - version = '1.1', - description = 'desc1' - WHERE artifact_id = 1`; - const actual = service.getUpdateModelQuery(dbName, model); + Insert into [dbo].[ta[[b]]le] + (artifact_name, group_path, artifact_content, name, version, description) + values ( + 'name1', + 'ADS', + , + 'title1', + '1.1', + 'desc1')`; + const actual = service.getInsertModelQuery(dbName, model); should.equal(actual.indexOf(expected) > 0, true); - //should.deepEqual(actual, expected); - }); it('getModelContentQuery should escape db name', async function (): Promise { diff --git a/extensions/machine-learning-services/src/test/views/models/azureModelsComponent.test.ts b/extensions/machine-learning-services/src/test/views/models/azureModelsComponent.test.ts index 85d918e946..e3cb1a78b2 100644 --- a/extensions/machine-learning-services/src/test/views/models/azureModelsComponent.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/azureModelsComponent.test.ts @@ -28,7 +28,7 @@ describe('Azure Models Component', () => { let testContext = createContext(); let parent = new ParentDialog(testContext.apiWrapper.object); - let view = new AzureModelsComponent(testContext.apiWrapper.object, parent); + let view = new AzureModelsComponent(testContext.apiWrapper.object, parent, false); view.registerComponent(testContext.view.modelBuilder); let accounts: azdata.Account[] = [ @@ -88,12 +88,15 @@ describe('Azure Models Component', () => { parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models }); }); await view.refresh(); - testContext.onClick.fire(); + testContext.onClick.fire(true); should.notEqual(view.data, undefined); - should.deepEqual(view.data?.account, accounts[0]); - should.deepEqual(view.data?.subscription, subscriptions[0]); - should.deepEqual(view.data?.group, groups[0]); - should.deepEqual(view.data?.workspace, workspaces[0]); - should.deepEqual(view.data?.model, models[0]); + should.equal(view.data?.length, 1); + if (view.data) { + should.deepEqual(view.data[0].account, accounts[0]); + should.deepEqual(view.data[0].subscription, subscriptions[0]); + should.deepEqual(view.data[0].group, groups[0]); + should.deepEqual(view.data[0].workspace, workspaces[0]); + should.deepEqual(view.data[0].model, models[0]); + } }); }); diff --git a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts index c5b40ceb11..77d4b8dc64 100644 --- a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts @@ -9,7 +9,7 @@ import 'mocha'; import { createContext } from './utils'; import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, - ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName + ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType } from '../../../views/models/modelViewBase'; import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces'; @@ -164,9 +164,25 @@ describe('Predict Wizard', () => { view.on(DownloadRegisteredModelEventName, () => { view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' }); }); + if (view.modelBrowsePage) { + view.modelBrowsePage.modelSourceType = ModelSourceType.Azure; + } await view.refresh(); should.notEqual(view.azureModelsComponent?.data, undefined); + + if (view.modelBrowsePage) { + view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels; + } + await view.refresh(); + testContext.onClick.fire(); + + should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels); should.notEqual(view.localModelsComponent?.data, undefined); + should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined); + if (view.modelBrowsePage?.registeredModelsComponent?.data) { + should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1); + } + should.notEqual(await view.getModelFileName(), undefined); await view.columnsSelectionPage?.onEnter(); diff --git a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts index d4f542fc0c..0bf7a7b9de 100644 --- a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts @@ -7,7 +7,7 @@ import * as azdata from 'azdata'; import * as should from 'should'; import 'mocha'; import { createContext } from './utils'; -import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName } from '../../../views/models/modelViewBase'; +import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType } from '../../../views/models/modelViewBase'; import { RegisteredModel } from '../../../modelManagement/interfaces'; import { azureResource } from '../../../typings/azure-resource'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; @@ -97,6 +97,10 @@ describe('Register Model Wizard', () => { view.on(ListAzureModelsEventName, () => { view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models }); }); + + if (view.modelBrowsePage) { + view.modelBrowsePage.modelSourceType = ModelSourceType.Azure; + } await view.refresh(); should.notEqual(view.azureModelsComponent?.data ,undefined); should.notEqual(view.localModelsComponent?.data, undefined); diff --git a/extensions/machine-learning-services/src/test/views/utils.ts b/extensions/machine-learning-services/src/test/views/utils.ts index 0d7467203f..e55c545cba 100644 --- a/extensions/machine-learning-services/src/test/views/utils.ts +++ b/extensions/machine-learning-services/src/test/views/utils.ts @@ -32,8 +32,13 @@ export function createViewContext(): ViewTestContext { onDidClick: onClick.event }); let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, { + checked: true, onDidClick: onClick.event }); + let checkbox: azdata.CheckBoxComponent = Object.assign({}, componentBase, { + checked: true, + onChanged: onClick.event + }); let container = { clearItems: () => { }, addItems: () => { }, @@ -58,6 +63,11 @@ export function createViewContext(): ViewTestContext { withProperties: () => radioButtonBuilder, withValidation: () => radioButtonBuilder }; + let checkBoxBuilder: azdata.ComponentBuilder = { + component: () => checkbox, + withProperties: () => checkBoxBuilder, + withValidation: () => checkBoxBuilder + }; let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, { onTextChanged: undefined!, onEnterKeyPressed: undefined!, @@ -85,6 +95,12 @@ export function createViewContext(): ViewTestContext { component: undefined! }); + let card: () => azdata.CardComponent = () => Object.assign({}, componentBase, { + label: '', + onDidActionClick: new vscode.EventEmitter().event, + onCardSelectedChanged: onClick.event + }); + let declarativeTableBuilder: azdata.ComponentBuilder = { component: () => declarativeTable(), withProperties: () => declarativeTableBuilder, @@ -127,6 +143,15 @@ export function createViewContext(): ViewTestContext { withProperties: () => inputBoxBuilder, withValidation: () => inputBoxBuilder }; + let cardBuilder: azdata.ComponentBuilder = { + component: () => { + let r = card(); + return r; + }, + withProperties: () => cardBuilder, + withValidation: () => cardBuilder + }; + let imageBuilder: azdata.ComponentBuilder = { component: () => { let r = image(); @@ -159,9 +184,9 @@ export function createViewContext(): ViewTestContext { flexContainer: () => flexBuilder, splitViewContainer: undefined!, dom: undefined!, - card: undefined!, + card: () => cardBuilder, inputBox: () => inputBoxBuilder, - checkBox: undefined!, + checkBox: () => checkBoxBuilder!, radioButton: () => radioButtonBuilder, webView: undefined!, editor: undefined!, diff --git a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts index 3dde5de665..8b36f31a2e 100644 --- a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts @@ -11,7 +11,7 @@ import { AzureModelsTable } from './azureModelsTable'; import { IDataComponent, AzureModelResource } from '../interfaces'; import { ModelArtifact } from './prediction/modelArtifact'; -export class AzureModelsComponent extends ModelViewBase implements IDataComponent { +export class AzureModelsComponent extends ModelViewBase implements IDataComponent { public azureModelsTable: AzureModelsTable | undefined; public azureFilterComponent: AzureResourceFilterComponent | undefined; @@ -23,7 +23,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen /** * Component to render a view to pick an azure model */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { super(apiWrapper, parent.root, parent); } @@ -33,7 +33,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this); - this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this); + this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect); this._loader = modelBuilder.loadingComponent() .withItem(this.azureModelsTable.component) .withProperties({ @@ -109,15 +109,16 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen /** * Returns selected data */ - public get data(): AzureModelResource | undefined { - return Object.assign({}, this.azureFilterComponent?.data, { - model: this.azureModelsTable?.data - }); + public get data(): AzureModelResource[] | undefined { + return this.azureModelsTable?.data ? this.azureModelsTable?.data.map(x => Object.assign({}, this.azureFilterComponent?.data, { + model: x + })) : undefined; } - public async getDownloadedModel(): Promise { - if (!this._downloadedFile) { - this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(this.data)); + public async getDownloadedModel(): Promise { + const data = this.data; + if (!this._downloadedFile && data && data.length > 0) { + this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(data[0])); } return this._downloadedFile; } diff --git a/extensions/machine-learning-services/src/views/models/azureModelsTable.ts b/extensions/machine-learning-services/src/views/models/azureModelsTable.ts index ea128c7cf7..eeeb581669 100644 --- a/extensions/machine-learning-services/src/views/models/azureModelsTable.ts +++ b/extensions/machine-learning-services/src/views/models/azureModelsTable.ts @@ -14,10 +14,10 @@ import { IDataComponent, AzureWorkspaceResource } from '../interfaces'; /** * View to render azure models in a table */ -export class AzureModelsTable extends ModelViewBase implements IDataComponent { +export class AzureModelsTable extends ModelViewBase implements IDataComponent { private _table: azdata.DeclarativeTableComponent; - private _selectedModelId: any; + private _selectedModel: WorkspaceModel[] = []; private _models: WorkspaceModel[] | undefined; private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter(); public readonly onModelSelectionChanged: vscode.Event = this._onModelSelectionChanged.event; @@ -25,7 +25,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent { - this._selectedModelId = model.id; + let selectModelButton: azdata.Component; + let onSelectItem = (checked: boolean) => { + const foundItem = this._selectedModel.find(x => x === model); + if (checked && !foundItem) { + this._selectedModel.push(model); + } else if (foundItem) { + this._selectedModel = this._selectedModel.filter(x => x !== model); + } this._onModelSelectionChanged.fire(); - }); + }; + if (this._multiSelect) { + const checkbox = this._modelBuilder.checkBox().withProperties({ + name: 'amlModel', + value: model.id, + width: 15, + height: 15, + checked: false + }).component(); + checkbox.onChanged(() => { + onSelectItem(checkbox.checked || false); + }); + selectModelButton = checkbox; + } else { + const radioButton = this._modelBuilder.radioButton().withProperties({ + name: 'amlModel', + value: model.id, + width: 15, + height: 15, + checked: false + }).component(); + radioButton.onDidClick(() => { + onSelectItem(radioButton.checked || false); + }); + selectModelButton = radioButton; + } + return [model.name, model.createdTime, model.frameworkVersion, selectModelButton]; } @@ -143,9 +168,9 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent x.id === this._selectedModelId); + public get data(): WorkspaceModel[] | undefined { + if (this._models && this._selectedModel) { + return this._selectedModel; } return undefined; } diff --git a/extensions/machine-learning-services/src/views/models/localModelsComponent.ts b/extensions/machine-learning-services/src/views/models/localModelsComponent.ts index 41eec22ada..b74f786282 100644 --- a/extensions/machine-learning-services/src/views/models/localModelsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/localModelsComponent.ts @@ -14,7 +14,7 @@ import { IDataComponent } from '../interfaces'; /** * View to pick local models file */ -export class LocalModelsComponent extends ModelViewBase implements IDataComponent { +export class LocalModelsComponent extends ModelViewBase implements IDataComponent { private _form: azdata.FormContainer | undefined; private _flex: azdata.FlexContainer | undefined; @@ -24,7 +24,7 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen /** * Creates new view */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { super(apiWrapper, parent.root, parent); } @@ -49,13 +49,15 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen let options: vscode.OpenDialogOptions = { canSelectFiles: true, canSelectFolders: false, - canSelectMany: false, + canSelectMany: this._multiSelect, filters: { 'ONNX File': ['onnx'] } }; const filePaths = await this.getLocalPaths(options); - if (this._localPath) { - this._localPath.value = filePaths && filePaths.length > 0 ? filePaths[0] : ''; + if (this._localPath && filePaths && filePaths.length > 0) { + this._localPath.value = this._multiSelect ? filePaths.join(';') : filePaths[0]; + } else if (this._localPath) { + this._localPath.value = ''; } }); @@ -96,8 +98,12 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen /** * Returns selected data */ - public get data(): string { - return this._localPath?.value || ''; + public get data(): string[] { + if (this._localPath?.value) { + return this._localPath?.value.split(';'); + } else { + return []; + } } /** diff --git a/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts b/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts new file mode 100644 index 0000000000..3bd3b6fc3a --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts @@ -0,0 +1,178 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase, ModelSourceType, ModelViewData } from './modelViewBase'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as constants from '../../common/constants'; +import { IPageView, IDataComponent } from '../interfaces'; +import { LocalModelsComponent } from './localModelsComponent'; +import { AzureModelsComponent } from './azureModelsComponent'; +import { CurrentModelsTable } from './registerModels/currentModelsTable'; +import * as utils from '../../common/utils'; + +/** + * View to pick model source + */ +export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _formBuilder: azdata.FormBuilder | undefined; + public localModelsComponent: LocalModelsComponent | undefined; + public azureModelsComponent: AzureModelsComponent | undefined; + public registeredModelsComponent: CurrentModelsTable | undefined; + + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { + super(apiWrapper, parent.root, parent); + } + + /** + * + * @param modelBuilder Register components + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + + this._formBuilder = modelBuilder.formContainer(); + this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this, this._multiSelect); + this.localModelsComponent.registerComponent(modelBuilder); + this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect); + this.azureModelsComponent.registerComponent(modelBuilder); + this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect); + this.registeredModelsComponent.registerComponent(modelBuilder); + this.refresh(); + this._form = this._formBuilder.component(); + return this._form; + } + + /** + * Returns selected data + */ + public get data(): ModelViewData[] { + return this.modelsViewData; + } + + /** + * Returns the component + */ + public get component(): azdata.Component | undefined { + return this._form; + } + + /** + * Refreshes the view + */ + public async refresh(): Promise { + if (this._formBuilder) { + if (this.modelSourceType === ModelSourceType.Local) { + if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { + this.azureModelsComponent.removeComponents(this._formBuilder); + this.registeredModelsComponent.removeComponents(this._formBuilder); + this.localModelsComponent.addComponents(this._formBuilder); + await this.localModelsComponent.refresh(); + } + + } else if (this.modelSourceType === ModelSourceType.Azure) { + if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { + this.localModelsComponent.removeComponents(this._formBuilder); + this.azureModelsComponent.addComponents(this._formBuilder); + this.registeredModelsComponent.removeComponents(this._formBuilder); + await this.azureModelsComponent.refresh(); + } + + } else if (this.modelSourceType === ModelSourceType.RegisteredModels) { + if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { + this.localModelsComponent.removeComponents(this._formBuilder); + this.azureModelsComponent.removeComponents(this._formBuilder); + this.registeredModelsComponent.addComponents(this._formBuilder); + await this.registeredModelsComponent.refresh(); + } + + } + } + } + + /** + * Returns page title + */ + public get title(): string { + return constants.modelSourcePageTitle; + } + + public validate(): Promise { + let validated = false; + if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) { + validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0; + + } else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) { + validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0; + + } else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) { + validated = this.registeredModelsComponent.data !== undefined && this.registeredModelsComponent.data.length > 0; + } + if (!validated) { + this.showErrorMessage(constants.invalidModelToSelectError); + } + return Promise.resolve(validated); + } + + public async onLeave(): Promise { + this.modelsViewData = []; + if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) { + if (this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0) { + this.modelsViewData = this.localModelsComponent.data.map(x => { + const fileName = utils.getFileName(x); + return { + modelData: x, + modelDetails: { + title: fileName, + fileName: fileName + } + }; + }); + } + + } else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) { + if (this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0) { + this.modelsViewData = this.azureModelsComponent.data.map(x => { + return { + modelData: { + account: x.account, + subscription: x.subscription, + group: x.group, + workspace: x.workspace, + model: x.model + }, + modelDetails: { + title: x.model?.name || '', + fileName: x.model?.name + } + }; + }); + } + + } else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) { + if (this.registeredModelsComponent.data !== undefined) { + this.modelsViewData = this.registeredModelsComponent.data.map(x => { + return { + modelData: x, + modelDetails: { + title: '' + } + }; + }); + } + } + } + + public async disposePage(): Promise { + if (this.azureModelsComponent) { + await this.azureModelsComponent.disposeComponent(); + + } + if (this.registeredModelsComponent) { + await this.registeredModelsComponent.disposeComponent(); + } + } +} diff --git a/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts b/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts index 3465ff0b95..9c9e0b1de5 100644 --- a/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts @@ -4,25 +4,21 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase } from './modelViewBase'; +import { ModelViewBase, ModelViewData } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IDataComponent } from '../interfaces'; -import { RegisteredModelDetails } from '../../modelManagement/interfaces'; /** * View to pick local models file */ -export class ModelDetailsComponent extends ModelViewBase implements IDataComponent { - - private _form: azdata.FormContainer | undefined; - private _nameComponent: azdata.InputBoxComponent | undefined; - private _descriptionComponent: azdata.InputBoxComponent | undefined; +export class ModelDetailsComponent extends ModelViewBase implements IDataComponent { + private _table: azdata.DeclarativeTableComponent | undefined; /** * Creates new view */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) { super(apiWrapper, parent.root, parent); } @@ -31,73 +27,162 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone * @param modelBuilder Register the components */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._nameComponent = modelBuilder.inputBox().withProperties({ - value: '', - width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength - }).component(); - this._descriptionComponent = modelBuilder.inputBox().withProperties({ - value: '', - multiline: true, - width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength, - hight: '50px' - }).component(); + this._table = modelBuilder.declarativeTable() + .withProperties( + { + columns: [ + { // Name + displayName: constants.modelFileName, + ariaLabel: constants.modelFileName, + valueType: azdata.DeclarativeDataType.string, + isReadOnly: true, + width: 150, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Name + displayName: constants.modelName, + ariaLabel: constants.modelName, + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 150, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Created + displayName: constants.modelDescription, + ariaLabel: constants.modelDescription, + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 100, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Action + displayName: '', + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + } + ], + data: [], + ariaLabel: constants.mlsConfigTitle + }) + .component(); - this._form = modelBuilder.formContainer().withFormItems([{ - title: constants.modelName, - component: this._nameComponent - }, { - title: constants.modelDescription, - component: this._descriptionComponent - }]).component(); - return this._form; + return this._table; } public addComponents(formBuilder: azdata.FormBuilder) { - if (this._nameComponent && this._descriptionComponent) { + if (this._table) { formBuilder.addFormItems([{ - title: constants.modelName, - component: this._nameComponent - }, { - title: constants.modelDescription, - component: this._descriptionComponent + title: '', + component: this._table }]); } } public removeComponents(formBuilder: azdata.FormBuilder) { - if (this._nameComponent && this._descriptionComponent) { + if (this._table) { formBuilder.removeFormItem({ - title: constants.modelName, - component: this._nameComponent - }); - formBuilder.removeFormItem({ - title: constants.modelDescription, - component: this._descriptionComponent + title: '', + component: this._table }); } } + /** + * Load data in the component + * @param workspaceResource Azure workspace + */ + public async loadData(): Promise { + + const models = this.modelsViewData; + if (this._table && models) { + + let tableData: any[][] = []; + tableData = tableData.concat(models.map(model => this.createTableRow(model))); + this._table.data = tableData; + } + } + + private createTableRow(model: ModelViewData | undefined): any[] { + if (this._modelBuilder && model && model.modelDetails) { + const nameComponent = this._modelBuilder.inputBox().withProperties({ + value: model.modelDetails.title, + width: this.componentMaxLength - 100, + required: true + }).component(); + const descriptionComponent = this._modelBuilder.inputBox().withProperties({ + value: model.modelDetails.description, + width: this.componentMaxLength + }).component(); + descriptionComponent.onTextChanged(() => { + if (model.modelDetails) { + model.modelDetails.description = descriptionComponent.value; + } + }); + nameComponent.onTextChanged(() => { + if (model.modelDetails) { + model.modelDetails.title = nameComponent.value || ''; + } + }); + let deleteButton = this._modelBuilder.button().withProperties({ + label: '', + title: constants.deleteTitle, + width: 15, + height: 15, + iconPath: { + dark: this.asAbsolutePath('images/dark/delete_inverse.svg'), + light: this.asAbsolutePath('images/light/delete.svg') + }, + }).component(); + deleteButton.onDidClick(async () => { + this.modelsViewData = this.modelsViewData.filter(x => x !== model); + await this.refresh(); + }); + return [model.modelDetails.fileName, nameComponent, descriptionComponent, deleteButton]; + } + + return []; + } /** * Returns selected data */ - public get data(): RegisteredModelDetails { - return { - title: this._nameComponent?.value || '', - description: this._descriptionComponent?.value - }; + public get data(): ModelViewData[] { + return this.modelsViewData; } /** * Returns the component */ public get component(): azdata.Component | undefined { - return this._form; + return this._table; } /** * Refreshes the view */ public async refresh(): Promise { + await this.loadData(); } } diff --git a/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts b/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts index f2baa27bd6..1e8fbb9779 100644 --- a/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts +++ b/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts @@ -4,17 +4,16 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase } from './modelViewBase'; +import { ModelViewBase, ModelViewData } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IPageView, IDataComponent } from '../interfaces'; import { ModelDetailsComponent } from './modelDetailsComponent'; -import { RegisteredModelDetails } from '../../modelManagement/interfaces'; /** * View to pick model details */ -export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent { +export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent { private _form: azdata.FormContainer | undefined; private _formBuilder: azdata.FormBuilder | undefined; @@ -31,9 +30,8 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this._formBuilder = modelBuilder.formContainer(); - this.modelDetails = new ModelDetailsComponent(this._apiWrapper, this); + this.modelDetails = new ModelDetailsComponent(this._apiWrapper, modelBuilder, this); this.modelDetails.registerComponent(modelBuilder); - this.modelDetails.addComponents(this._formBuilder); this.refresh(); this._form = this._formBuilder.component(); @@ -43,7 +41,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC /** * Returns selected data */ - public get data(): RegisteredModelDetails | undefined { + public get data(): ModelViewData[] | undefined { return this.modelDetails?.data; } @@ -58,6 +56,13 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC * Refreshes the view */ public async refresh(): Promise { + if (this.modelDetails) { + await this.modelDetails.refresh(); + } + } + + public async onEnter(): Promise { + await this.refresh(); } /** @@ -68,7 +73,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC } public validate(): Promise { - if (this.data && this.data.title) { + if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.title)) { return Promise.resolve(true); } else { this.showErrorMessage(constants.modelNameRequiredError); diff --git a/extensions/machine-learning-services/src/views/models/modelManagementController.ts b/extensions/machine-learning-services/src/views/models/modelManagementController.ts index 97c4bcfe71..b31884008a 100644 --- a/extensions/machine-learning-services/src/views/models/modelManagementController.ts +++ b/extensions/machine-learning-services/src/views/models/modelManagementController.ts @@ -9,15 +9,15 @@ import { azureResource } from '../../typings/azure-resource'; import { ApiWrapper } from '../../common/apiWrapper'; import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; -import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces'; +import { RegisteredModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { DeployedModelService } from '../../modelManagement/deployedModelService'; import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog'; import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, - ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, - RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, - ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName + ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName, + ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, + ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData } from './modelViewBase'; import { ControllerBase } from '../controllerBase'; import { RegisterModelWizard } from './registerModels/registerModelWizard'; @@ -122,17 +122,17 @@ export class ModelManagementController extends ControllerBase { await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService); }); view.on(RegisterLocalModelEventName, async (arg) => { - let registerArgs = arg; - await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, registerArgs.filePath, registerArgs.details); + let models = arg; + await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models); view.refresh(); }); view.on(RegisterModelEventName, async () => { await this.executeAction(view, RegisterModelEventName, this.registerModel, view, this, this._apiWrapper, this._root); }); view.on(RegisterAzureModelEventName, async (arg) => { - let registerArgs = arg; + let models = arg; await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService, - registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details); + models); }); view.on(DownloadAzureModelEventName, async (arg) => { let registerArgs = arg; @@ -161,7 +161,8 @@ export class ModelManagementController extends ControllerBase { await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService, model); }); - view.on(SourceModelSelectedEventName, () => { + view.on(SourceModelSelectedEventName, (arg) => { + view.modelSourceType = arg; view.refresh(); }); } @@ -217,35 +218,46 @@ export class ModelManagementController extends ControllerBase { return await service.getModels(account, subscription, resourceGroup, workspace) || []; } - private async registerLocalModel(service: DeployedModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise { - if (filePath) { - await service.deployLocalModel(filePath, details); + private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise { + if (models) { + await Promise.all(models.map(async (model) => { + const localModel = model.modelData; + if (localModel) { + await service.deployLocalModel(localModel, model.modelDetails); + } + })); } else { throw Error(constants.invalidModelToRegisterError); - } } private async registerAzureModel( azureService: AzureModelRegistryService, service: DeployedModelService, - account: azdata.Account | undefined, - subscription: azureResource.AzureResourceSubscription | undefined, - resourceGroup: azureResource.AzureResource | undefined, - workspace: Workspace | undefined, - model: WorkspaceModel | undefined, - details: RegisteredModelDetails | undefined): Promise { - if (!account || !subscription || !resourceGroup || !workspace || !model || !details) { + models: ModelViewData[] | undefined): Promise { + if (!models) { throw Error(constants.invalidAzureResourceError); } - const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model); - if (filePath) { - await service.deployLocalModel(filePath, details); - await fs.promises.unlink(filePath); - } else { - throw Error(constants.invalidModelToRegisterError); - } + await Promise.all(models.map(async (model) => { + const azureModel = model.modelData; + if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) { + let filePath: string | undefined; + try { + const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group, + azureModel.workspace, azureModel.model); + if (filePath) { + await service.deployLocalModel(filePath, model.modelDetails); + } else { + throw Error(constants.invalidModelToRegisterError); + } + } finally { + if (filePath) { + await fs.promises.unlink(filePath); + } + } + } + })); } public async getDatabaseList(predictService: PredictService): Promise { diff --git a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts index c8b046b5c7..5f156d1ea2 100644 --- a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts +++ b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts @@ -4,14 +4,11 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase } from './modelViewBase'; +import { ModelViewBase, ModelSourceType } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IPageView, IDataComponent } from '../interfaces'; -import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent'; -import { LocalModelsComponent } from './localModelsComponent'; -import { AzureModelsComponent } from './azureModelsComponent'; -import { CurrentModelsTable } from './registerModels/currentModelsTable'; +import { ModelSourcesComponent } from './modelSourcesComponent'; /** * View to pick model source @@ -21,9 +18,6 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo private _form: azdata.FormContainer | undefined; private _formBuilder: azdata.FormBuilder | undefined; public modelResources: ModelSourcesComponent | undefined; - public localModelsComponent: LocalModelsComponent | undefined; - public azureModelsComponent: AzureModelsComponent | undefined; - public registeredModelsComponent: CurrentModelsTable | undefined; constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) { super(apiWrapper, parent.root, parent); @@ -38,14 +32,7 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo this._formBuilder = modelBuilder.formContainer(); this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options); this.modelResources.registerComponent(modelBuilder); - this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this); - this.localModelsComponent.registerComponent(modelBuilder); - this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this); - this.azureModelsComponent.registerComponent(modelBuilder); this.modelResources.addComponents(this._formBuilder); - this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this); - this.registeredModelsComponent.registerComponent(modelBuilder); - this.refresh(); this._form = this._formBuilder.component(); return this._form; } @@ -68,33 +55,6 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo * Refreshes the view */ public async refresh(): Promise { - if (this._formBuilder) { - if (this.modelResources && this.modelResources.data === ModelSourceType.Local) { - if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { - this.azureModelsComponent.removeComponents(this._formBuilder); - this.registeredModelsComponent.removeComponents(this._formBuilder); - this.localModelsComponent.addComponents(this._formBuilder); - await this.localModelsComponent.refresh(); - } - - } else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) { - if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { - this.localModelsComponent.removeComponents(this._formBuilder); - this.azureModelsComponent.addComponents(this._formBuilder); - this.registeredModelsComponent.removeComponents(this._formBuilder); - await this.azureModelsComponent.refresh(); - } - - } else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels) { - if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) { - this.localModelsComponent.removeComponents(this._formBuilder); - this.azureModelsComponent.removeComponents(this._formBuilder); - this.registeredModelsComponent.addComponents(this._formBuilder); - await this.registeredModelsComponent.refresh(); - } - - } - } } /** @@ -104,30 +64,6 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo return constants.modelSourcePageTitle; } - public validate(): Promise { - let validated = false; - if (this.modelResources && this.modelResources.data === ModelSourceType.Local && this.localModelsComponent) { - validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0; - - } else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure && this.azureModelsComponent) { - validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.model !== undefined; - - } else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels && this.registeredModelsComponent) { - validated = this.registeredModelsComponent.data !== undefined; - } - if (!validated) { - this.showErrorMessage(constants.invalidModelToSelectError); - } - return Promise.resolve(validated); - } - public async disposePage(): Promise { - if (this.azureModelsComponent) { - await this.azureModelsComponent.disposeComponent(); - - } - if (this.registeredModelsComponent) { - await this.registeredModelsComponent.disposeComponent(); - } } } diff --git a/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts b/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts index ef542c58df..aa75f6f153 100644 --- a/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts +++ b/extensions/machine-learning-services/src/views/models/modelSourcesComponent.ts @@ -4,16 +4,11 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase, SourceModelSelectedEventName } from './modelViewBase'; +import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IDataComponent } from '../interfaces'; -export enum ModelSourceType { - Local, - Azure, - RegisteredModels -} /** * View to pick model source */ @@ -21,9 +16,9 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone private _form: azdata.FormContainer | undefined; private _flexContainer: azdata.FlexContainer | undefined; - private _amlModel: azdata.RadioButtonComponent | undefined; - private _localModel: azdata.RadioButtonComponent | undefined; - private _registeredModels: azdata.RadioButtonComponent | undefined; + private _amlModel: azdata.CardComponent | undefined; + private _localModel: azdata.CardComponent | undefined; + private _registeredModels: azdata.CardComponent | undefined; private _sourceType: ModelSourceType = ModelSourceType.Local; constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) { @@ -35,45 +30,61 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone * @param modelBuilder Register components */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._localModel = modelBuilder.radioButton() + + this._localModel = modelBuilder.card() .withProperties({ value: 'local', name: 'modelLocation', label: constants.localModelSource, - checked: this._options[0] === ModelSourceType.Local + selected: this._options[0] === ModelSourceType.Local, + cardType: azdata.CardType.VerticalButton, + width: 50 }).component(); - - - this._amlModel = modelBuilder.radioButton() + this._amlModel = modelBuilder.card() .withProperties({ value: 'aml', name: 'modelLocation', label: constants.azureModelSource, - checked: this._options[0] === ModelSourceType.Azure + selected: this._options[0] === ModelSourceType.Azure, + cardType: azdata.CardType.VerticalButton, + width: 50 }).component(); - this._registeredModels = modelBuilder.radioButton() + this._registeredModels = modelBuilder.card() .withProperties({ value: 'registered', name: 'modelLocation', label: constants.registeredModelsSource, - checked: this._options[0] === ModelSourceType.RegisteredModels + selected: this._options[0] === ModelSourceType.RegisteredModels, + cardType: azdata.CardType.VerticalButton, + width: 50 }).component(); - this._localModel.onDidClick(() => { + this._localModel.onCardSelectedChanged(() => { this._sourceType = ModelSourceType.Local; - this.sendRequest(SourceModelSelectedEventName); - + this.sendRequest(SourceModelSelectedEventName, this._sourceType); + if (this._amlModel && this._registeredModels) { + this._amlModel.selected = false; + this._registeredModels.selected = false; + } }); - this._amlModel.onDidClick(() => { + this._amlModel.onCardSelectedChanged(() => { this._sourceType = ModelSourceType.Azure; - this.sendRequest(SourceModelSelectedEventName); + this.sendRequest(SourceModelSelectedEventName, this._sourceType); + if (this._localModel && this._registeredModels) { + this._localModel.selected = false; + this._registeredModels.selected = false; + } }); - this._registeredModels.onDidClick(() => { + this._registeredModels.onCardSelectedChanged(() => { this._sourceType = ModelSourceType.RegisteredModels; - this.sendRequest(SourceModelSelectedEventName); + this.sendRequest(SourceModelSelectedEventName, this._sourceType); + if (this._localModel && this._amlModel) { + this._localModel.selected = false; + this._amlModel.selected = false; + } }); - let components: azdata.RadioButtonComponent[] = []; + let components: azdata.Component[] = []; this._options.forEach(option => { switch (option) { @@ -95,10 +106,11 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone } }); this._sourceType = this._options[0]; + this.sendRequest(SourceModelSelectedEventName, this._sourceType); this._flexContainer = modelBuilder.flexContainer() .withLayout({ - flexFlow: 'column', + flexFlow: 'row', justifyContent: 'space-between' }).withItems(components).component(); @@ -112,13 +124,13 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone public addComponents(formBuilder: azdata.FormBuilder) { if (this._flexContainer) { - formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer }); + formBuilder.addFormItem({ title: '', component: this._flexContainer }); } } public removeComponents(formBuilder: azdata.FormBuilder) { if (this._flexContainer) { - formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer }); + formBuilder.removeFormItem({ title: '', component: this._flexContainer }); } } diff --git a/extensions/machine-learning-services/src/views/models/modelViewBase.ts b/extensions/machine-learning-services/src/views/models/modelViewBase.ts index 5b69ecc19b..f1c4cd4b2b 100644 --- a/extensions/machine-learning-services/src/views/models/modelViewBase.ts +++ b/extensions/machine-learning-services/src/views/models/modelViewBase.ts @@ -13,6 +13,7 @@ import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/ import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { AzureWorkspaceResource, AzureModelResource } from '../interfaces'; + export interface AzureResourceEventArgs extends AzureWorkspaceResource { } @@ -20,17 +21,22 @@ export interface RegisterModelEventArgs extends AzureWorkspaceResource { details?: RegisteredModelDetails } -export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs { - model?: WorkspaceModel; -} - export interface PredictModelEventArgs extends PredictParameters { model?: RegisteredModel; filePath?: string; } -export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs { - filePath?: string; + +export enum ModelSourceType { + Local, + Azure, + RegisteredModels +} + +export interface ModelViewData { + modelFile?: string; + modelData: AzureModelResource | string | RegisteredModel; + modelDetails?: RegisteredModelDetails; } // Event names @@ -58,6 +64,9 @@ export const LoadModelParametersEventName = 'loadModelParameters'; */ export abstract class ModelViewBase extends ViewBase { + private _modelSourceType: ModelSourceType = ModelSourceType.Local; + private _modelsViewData: ModelViewData[] = []; + constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) { super(apiWrapper, root, parent); } @@ -147,12 +156,8 @@ export abstract class ModelViewBase extends ViewBase { * registers local model * @param localFilePath local file path */ - public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModelDetails | undefined): Promise { - const args: RegisterLocalModelEventArgs = { - filePath: localFilePath, - details: details - }; - return await this.sendDataRequest(RegisterLocalModelEventName, args); + public async registerLocalModel(models: ModelViewData[]): Promise { + return await this.sendDataRequest(RegisterLocalModelEventName, models); } /** @@ -182,11 +187,8 @@ export abstract class ModelViewBase extends ViewBase { * registers azure model * @param args azure resource */ - public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModelDetails | undefined): Promise { - const args: RegisterAzureModelEventArgs = Object.assign({}, resource, { - details: details - }); - return await this.sendDataRequest(RegisterAzureModelEventName, args); + public async registerAzureModel(models: ModelViewData[]): Promise { + return await this.sendDataRequest(RegisterAzureModelEventName, models); } /** @@ -215,6 +217,50 @@ export abstract class ModelViewBase extends ViewBase { return await this.sendDataRequest(ListGroupsEventName, args); } + /** + * Sets model source type + */ + public set modelSourceType(value: ModelSourceType) { + if (this.parent) { + this.parent.modelSourceType = value; + } else { + this._modelSourceType = value; + } + } + + /** + * Returns model source type + */ + public get modelSourceType(): ModelSourceType { + if (this.parent) { + return this.parent.modelSourceType; + } else { + return this._modelSourceType; + } + } + + /** + * Sets model source type + */ + public set modelsViewData(value: ModelViewData[]) { + if (this.parent) { + this.parent.modelsViewData = value; + } else { + this._modelsViewData = value; + } + } + + /** + * Returns model source type + */ + public get modelsViewData(): ModelViewData[] { + if (this.parent) { + return this.parent.modelsViewData; + } else { + return this._modelsViewData; + } + } + /** * lists azure workspaces * @param account azure account diff --git a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts index 61d41bebf0..82eecb1127 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts @@ -4,9 +4,9 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase } from '../modelViewBase'; +import { ModelViewBase, ModelSourceType } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; -import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent'; +import { ModelSourcesComponent } from '../modelSourcesComponent'; import { LocalModelsComponent } from '../localModelsComponent'; import { AzureModelsComponent } from '../azureModelsComponent'; import * as constants from '../../../common/constants'; @@ -15,6 +15,7 @@ import { ModelSourcePage } from '../modelSourcePage'; import { ColumnsSelectionPage } from './columnsSelectionPage'; import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ModelArtifact } from './modelArtifact'; +import { ModelBrowsePage } from '../modelBrowsePage'; /** * Wizard to register a model @@ -23,6 +24,7 @@ export class PredictWizard extends ModelViewBase { public modelSourcePage: ModelSourcePage | undefined; public columnsSelectionPage: ColumnsSelectionPage | undefined; + public modelBrowsePage: ModelBrowsePage | undefined; public wizardView: WizardView | undefined; private _parentView: ModelViewBase | undefined; @@ -40,10 +42,12 @@ export class PredictWizard extends ModelViewBase { public async open(): Promise { this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]); this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this); + this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false); this.wizardView = new WizardView(this._apiWrapper); let wizard = this.wizardView.createWizard(constants.makePredictionTitle, [this.modelSourcePage, + this.modelBrowsePage, this.columnsSelectionPage]); this.mainViewPanel = wizard; @@ -57,7 +61,10 @@ export class PredictWizard extends ModelViewBase { await this.onClose(); }); wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => { - let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; + let validated: boolean = true; + if (pageInfo.newPage > pageInfo.lastPage) { + validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; + } if (validated) { if (pageInfo.newPage === undefined) { this.onLoading(); @@ -96,20 +103,20 @@ export class PredictWizard extends ModelViewBase { } public get localModelsComponent(): LocalModelsComponent | undefined { - return this.modelSourcePage?.localModelsComponent; + return this.modelBrowsePage?.localModelsComponent; } public get azureModelsComponent(): AzureModelsComponent | undefined { - return this.modelSourcePage?.azureModelsComponent; + return this.modelBrowsePage?.azureModelsComponent; } public async getModelFileName(): Promise { if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) { - return new ModelArtifact(this.localModelsComponent.data, false); + return new ModelArtifact(this.localModelsComponent.data[0], false); } else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) { return await this.azureModelsComponent.getDownloadedModel(); - } else if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) { - return await this.modelSourcePage.registeredModelsComponent.getDownloadedModel(); + } else if (this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) { + return await this.modelBrowsePage.registeredModelsComponent.getDownloadedModel(); } return undefined; } @@ -118,8 +125,10 @@ export class PredictWizard extends ModelViewBase { try { let modelFilePath: string | undefined; let registeredModel: RegisteredModel | undefined = undefined; - if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) { - registeredModel = this.modelSourcePage?.registeredModelsComponent?.data; + if (this.modelResources && this.modelResources.data && this.modelResources.data === ModelSourceType.RegisteredModels + && this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) { + const data = this.modelBrowsePage?.registeredModelsComponent?.data; + registeredModel = data && data.length > 0 ? data[0] : undefined; } else { const artifact = await this.getModelFileName(); modelFilePath = artifact?.filePath; diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts index 2723c82dc6..c05aa078d1 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts @@ -33,7 +33,7 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView { * @param modelBuilder register the components */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._dataTable = new CurrentModelsTable(this._apiWrapper, this); + this._dataTable = new CurrentModelsTable(this._apiWrapper, this, false); this._dataTable.registerComponent(modelBuilder); this._tableComponent = this._dataTable.component; diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts index 7efd3bdc8b..1da086f8f3 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts @@ -15,11 +15,11 @@ import { ModelArtifact } from '../prediction/modelArtifact'; /** * View to render registered models table */ -export class CurrentModelsTable extends ModelViewBase implements IDataComponent { +export class CurrentModelsTable extends ModelViewBase implements IDataComponent { private _table: azdata.DeclarativeTableComponent | undefined; private _modelBuilder: azdata.ModelBuilder | undefined; - private _selectedModel: any; + private _selectedModel: RegisteredModel[] = []; private _loader: azdata.LoadingComponent | undefined; private _downloadedFile: ModelArtifact | undefined; private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter(); @@ -28,7 +28,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< /** * Creates new view */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { super(apiWrapper, parent.root, parent); } @@ -161,17 +161,45 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< private createTableRow(model: RegisteredModel): any[] { if (this._modelBuilder) { - let selectModelButton = this._modelBuilder.radioButton().withProperties({ - name: 'amlModel', - value: model.id, - width: 15, - height: 15, - checked: false - }).component(); - selectModelButton.onDidClick(async () => { - this._selectedModel = model; - await this.onModelSelected(); - }); + let selectModelButton: azdata.Component; + let onSelectItem = (checked: boolean) => { + if (!this._multiSelect) { + this._selectedModel = []; + } + const foundItem = this._selectedModel.find(x => x === model); + if (checked && !foundItem) { + this._selectedModel.push(model); + } else if (foundItem) { + this._selectedModel = this._selectedModel.filter(x => x !== model); + } + this.onModelSelected(); + }; + if (this._multiSelect) { + const checkbox = this._modelBuilder.checkBox().withProperties({ + name: 'amlModel', + value: model.id, + width: 15, + height: 15, + checked: false + }).component(); + checkbox.onChanged(() => { + onSelectItem(checkbox.checked || false); + }); + selectModelButton = checkbox; + } else { + const radioButton = this._modelBuilder.radioButton().withProperties({ + name: 'amlModel', + value: model.id, + width: 15, + height: 15, + checked: false + }).component(); + radioButton.onDidClick(() => { + onSelectItem(radioButton.checked || false); + }); + selectModelButton = radioButton; + } + return [model.artifactName, model.title, model.created, selectModelButton]; } @@ -189,13 +217,13 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< /** * Returns selected data */ - public get data(): RegisteredModel | undefined { + public get data(): RegisteredModel[] | undefined { return this._selectedModel; } - public async getDownloadedModel(): Promise { - if (!this._downloadedFile) { - this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data)); + public async getDownloadedModel(): Promise { + if (!this._downloadedFile && this.data && this.data.length > 0) { + this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data[0])); } return this._downloadedFile; } diff --git a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts index 2d42bf00e7..d965ee6c19 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts @@ -4,15 +4,16 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase } from '../modelViewBase'; +import { ModelViewBase, ModelSourceType } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; -import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent'; +import { ModelSourcesComponent } from '../modelSourcesComponent'; import { LocalModelsComponent } from '../localModelsComponent'; import { AzureModelsComponent } from '../azureModelsComponent'; import * as constants from '../../../common/constants'; import { WizardView } from '../../wizardView'; import { ModelSourcePage } from '../modelSourcePage'; import { ModelDetailsPage } from '../modelDetailsPage'; +import { ModelBrowsePage } from '../modelBrowsePage'; /** * Wizard to register a model @@ -20,6 +21,7 @@ import { ModelDetailsPage } from '../modelDetailsPage'; export class RegisterModelWizard extends ModelViewBase { public modelSourcePage: ModelSourcePage | undefined; + public modelBrowsePage: ModelBrowsePage | undefined; public modelDetailsPage: ModelDetailsPage | undefined; public wizardView: WizardView | undefined; private _parentView: ModelViewBase | undefined; @@ -38,16 +40,20 @@ export class RegisterModelWizard extends ModelViewBase { public async open(): Promise { this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this); this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this); + this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this); this.wizardView = new WizardView(this._apiWrapper); - let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelDetailsPage]); + let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]); this.mainViewPanel = wizard; wizard.doneButton.label = constants.azureRegisterModel; wizard.generateScriptButton.hidden = true; wizard.displayPageTitles = true; wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => { - let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; + let validated: boolean = true; + if (pageInfo.newPage > pageInfo.lastPage) { + validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; + } if (validated && pageInfo.newPage === undefined) { wizard.cancelButton.enabled = false; wizard.backButton.enabled = false; @@ -71,19 +77,19 @@ export class RegisterModelWizard extends ModelViewBase { } public get localModelsComponent(): LocalModelsComponent | undefined { - return this.modelSourcePage?.localModelsComponent; + return this.modelBrowsePage?.localModelsComponent; } public get azureModelsComponent(): AzureModelsComponent | undefined { - return this.modelSourcePage?.azureModelsComponent; + return this.modelBrowsePage?.azureModelsComponent; } private async registerModel(): Promise { try { if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) { - await this.registerLocalModel(this.localModelsComponent.data, this.modelDetailsPage?.data); + await this.registerLocalModel(this.modelsViewData); } else { - await this.registerAzureModel(this.azureModelsComponent?.data, this.modelDetailsPage?.data); + await this.registerAzureModel(this.modelsViewData); } this.showInfoMessage(constants.modelRegisteredSuccessfully); return true; @@ -93,14 +99,10 @@ export class RegisterModelWizard extends ModelViewBase { } } - private loadPages(): void { - } - /** * Refresh the pages */ public async refresh(): Promise { - this.loadPages(); await this.wizardView?.refresh(); } }