diff --git a/extensions/machine-learning/src/common/constants.ts b/extensions/machine-learning/src/common/constants.ts index 9df654e946..018c340dab 100644 --- a/extensions/machine-learning/src/common/constants.ts +++ b/extensions/machine-learning/src/common/constants.ts @@ -194,7 +194,12 @@ export const azureModels = localize('models.azureModels', "Models"); export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models"); export const localModelsTitle = localize('models.localModelsTitle', "Local models"); export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location"); -export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?"); +export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Select model source type"); +export const localModelSourceDescriptionForImport = localize('models.localModelSourceDescriptionForImport', "‘File Upload’ is selected. This allows you to import a model file from your local machine into a model database in this SQL instance. Click ‘Next’ to continue.​"); +export const azureModelSourceDescriptionForImport = localize('models.azureModelSourceDescriptionForImport', "‘Azure Machine Learning’ is selected. This allows you to import models stored in Azure Machine Learning workspaces in a model database in this SQL instance. Click ‘Next’ to continue.​​"); +export const localModelSourceDescriptionForPredict = localize('models.localModelSourceDescriptionForPredict', "‘File Upload’ is selected. This allows you to upload a model file from your local machine. Click ‘Next’ to continue.​​"); +export const importedModelSourceDescriptionForPredict = localize('models.importedModelSourceDescriptionForPredict', "‘Imported Models’ is selected. This allows you to choose from models stored in a model table in your database. Click ‘Next’ to continue.​"); +export const azureModelSourceDescriptionForPredict = localize('models.azureModelSourceDescriptionForPredict', "‘Azure Machine Learning’ is selected. This allows you to choose from models stored in Azure Machine Learning workspaces. Click ‘Next’ to continue.​"); export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Select or enter the location to import the models to"); export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map source data to model"); export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details"); @@ -202,7 +207,7 @@ export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "S export const modelLocalSourceTooltip = localize('models.modelLocalSourceTooltip', "File paths of the models to import"); export const onnxNotSupportedError = localize('models.onnxNotSupportedError', "ONNX runtime is not supported in current server"); export const currentModelsTitle = localize('models.currentModelsTitle', "Models"); -export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy"); +export const importModelDoneButton = localize('models.importModelDoneButton', "Import"); export const predictModel = localize('models.predictModel', "Predict"); export const registerModelTitle = localize('models.RegisterWizard', "Import models"); export const importedModelTitle = localize('models.importedModelTitle', "Imported models"); diff --git a/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts b/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts index 6cd7cc9725..bcd7b6afb6 100644 --- a/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts +++ b/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts @@ -4,7 +4,7 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase, ModelSourceType } from '../modelViewBase'; +import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; import { ModelSourcesComponent } from '../modelSourcesComponent'; import { LocalModelsComponent } from '../localModelsComponent'; @@ -34,6 +34,7 @@ export class ImportModelWizard extends ModelViewBase { parent?: ModelViewBase) { super(apiWrapper, root); this._parentView = parent; + this.modelActionType = ModelActionType.Import; } /** @@ -49,7 +50,7 @@ export class ImportModelWizard extends ModelViewBase { let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage, this.modelImportTargetPage]); this.mainViewPanel = wizard; - wizard.doneButton.label = constants.azureRegisterModel; + wizard.doneButton.label = constants.importModelDoneButton; wizard.generateScriptButton.hidden = true; wizard.displayPageTitles = true; wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => { diff --git a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts index d644f9ecc4..c8598ef27f 100644 --- a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts +++ b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts @@ -4,7 +4,7 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase'; +import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType, ModelActionType } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IDataComponent } from '../interfaces'; @@ -16,10 +16,12 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone private _form: azdata.FormContainer | undefined; private _flexContainer: azdata.FlexContainer | undefined; - private _amlModel: azdata.CardComponent | undefined; - private _localModel: azdata.CardComponent | undefined; - private _registeredModels: azdata.CardComponent | undefined; + private _amlModel: azdata.RadioCard | undefined; + private _localModel: azdata.RadioCard | undefined; + private _registeredModels: azdata.RadioCard | undefined; private _sourceType: ModelSourceType = ModelSourceType.Local; + private _defaultSourceType = ModelSourceType.Local; + private _selectedSourceLabel: azdata.TextComponent | undefined; constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) { super(apiWrapper, parent.root, parent); @@ -31,65 +33,44 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._sourceType = this._options && this._options.length > 0 ? this._options[0] : ModelSourceType.Local; + this._sourceType = this._options && this._options.length > 0 ? this._options[0] : this._defaultSourceType; this.modelSourceType = this._sourceType; - this._localModel = modelBuilder.card() - .withProperties({ - value: 'local', - name: 'modelLocation', - label: constants.localModelSource, - selected: this._sourceType === ModelSourceType.Local, - cardType: azdata.CardType.VerticalButton, - iconPath: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') }, - width: 50 - }).component(); - this._amlModel = modelBuilder.card() - .withProperties({ - value: 'aml', - name: 'modelLocation', - label: constants.azureModelSource, - selected: this._sourceType === ModelSourceType.Azure, - cardType: azdata.CardType.VerticalButton, - iconPath: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') }, - width: 50 - }).component(); + let selectedCardId: string = this.convertSourceIdToString(this._sourceType); - this._registeredModels = modelBuilder.card() - .withProperties({ - value: 'registered', - name: 'modelLocation', - label: constants.registeredModelsSource, - selected: this._sourceType === ModelSourceType.RegisteredModels, - cardType: azdata.CardType.VerticalButton, - iconPath: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') }, - width: 50 - }).component(); + this._localModel = { + descriptions: [{ + textValue: constants.localModelSource, + textStyles: { + 'font-size': '14px' + } + }], + id: this.convertSourceIdToString(ModelSourceType.Local), + icon: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') } + }; + this._amlModel = { + descriptions: [{ + textValue: constants.azureModelSource, + textStyles: { + 'font-size': '14px' + } + }], - this._localModel.onCardSelectedChanged(() => { - this._sourceType = ModelSourceType.Local; - this.sendRequest(SourceModelSelectedEventName, this._sourceType); - if (this._amlModel && this._registeredModels) { - this._amlModel.selected = false; - this._registeredModels.selected = false; - } - }); - this._amlModel.onCardSelectedChanged(() => { - this._sourceType = ModelSourceType.Azure; - this.sendRequest(SourceModelSelectedEventName, this._sourceType); - if (this._localModel && this._registeredModels) { - this._localModel.selected = false; - this._registeredModels.selected = false; - } - }); - this._registeredModels.onCardSelectedChanged(() => { - this._sourceType = ModelSourceType.RegisteredModels; - this.sendRequest(SourceModelSelectedEventName, this._sourceType); - if (this._localModel && this._amlModel) { - this._localModel.selected = false; - this._amlModel.selected = false; - } - }); - let components: azdata.Component[] = []; + id: this.convertSourceIdToString(ModelSourceType.Azure), + icon: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') } + }; + + this._registeredModels = { + descriptions: [{ + textValue: constants.registeredModelsSource, + textStyles: { + 'font-size': '14px' + } + }], + id: this.convertSourceIdToString(ModelSourceType.RegisteredModels), + icon: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') } + }; + + let components: azdata.RadioCard[] = []; this._options.forEach(option => { switch (option) { @@ -110,29 +91,95 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone break; } }); - this._flexContainer = modelBuilder.flexContainer() - .withLayout({ - flexFlow: 'row', - justifyContent: 'space-between' - }).withItems(components).component(); + let radioCardGroup = modelBuilder.radioCardGroup() + .withProperties({ + cards: components, + iconHeight: '100px', + iconWidth: '100px', + cardWidth: '170px', + cardHeight: '170px', + ariaLabel: 'test', + selectedCardId: selectedCardId + }).component(); + this._flexContainer = modelBuilder.flexContainer().withLayout({ + flexFlow: 'column' + }).withItems([radioCardGroup]).component(); + this._selectedSourceLabel = modelBuilder.text().withProperties({ + value: this.getSourceTypeDescription(this._sourceType), + CSSStyles: { + 'font-size': '14px', + 'margin': '0', + 'width': '438px' + } + }).component(); + + this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => { + this._sourceType = this.convertSourceIdToEnum(cardId); + if (this._selectedSourceLabel) { + this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType); + } + this.sendRequest(SourceModelSelectedEventName, this._sourceType); + })); this._form = modelBuilder.formContainer().withFormItems([{ title: '', component: this._flexContainer + }, { + title: '', + component: this._selectedSourceLabel }]).component(); return this._form; } + private convertSourceIdToString(sourceId: ModelSourceType): string { + return sourceId.toString(); + } + + private convertSourceIdToEnum(sourceId: string): ModelSourceType { + switch (sourceId) { + case ModelSourceType.Local.toString(): + return ModelSourceType.Local; + case ModelSourceType.Azure.toString(): + return ModelSourceType.Azure; + case ModelSourceType.RegisteredModels.toString(): + return ModelSourceType.RegisteredModels; + } + return this._defaultSourceType; + } + + private getSourceTypeDescription(sourceId: ModelSourceType): string { + if (this.modelActionType === ModelActionType.Import) { + switch (sourceId) { + case ModelSourceType.Local: + return constants.localModelSourceDescriptionForImport; + case ModelSourceType.Azure: + return constants.azureModelSourceDescriptionForImport; + } + } else if (this.modelActionType === ModelActionType.Predict) { + switch (sourceId) { + case ModelSourceType.Local: + return constants.localModelSourceDescriptionForPredict; + case ModelSourceType.Azure: + return constants.azureModelSourceDescriptionForPredict; + case ModelSourceType.RegisteredModels: + return constants.importedModelSourceDescriptionForPredict; + } + } + return ''; + } + public addComponents(formBuilder: azdata.FormBuilder) { - if (this._flexContainer) { + if (this._flexContainer && this._selectedSourceLabel) { formBuilder.addFormItem({ title: '', component: this._flexContainer }); + formBuilder.addFormItem({ title: '', component: this._selectedSourceLabel }); } } public removeComponents(formBuilder: azdata.FormBuilder) { - if (this._flexContainer) { + if (this._flexContainer && this._selectedSourceLabel) { formBuilder.removeFormItem({ title: '', component: this._flexContainer }); + formBuilder.removeFormItem({ title: '', component: this._selectedSourceLabel }); } } diff --git a/extensions/machine-learning/src/views/models/modelViewBase.ts b/extensions/machine-learning/src/views/models/modelViewBase.ts index cc91716eca..b10c3b5252 100644 --- a/extensions/machine-learning/src/views/models/modelViewBase.ts +++ b/extensions/machine-learning/src/views/models/modelViewBase.ts @@ -28,9 +28,14 @@ export interface PredictModelEventArgs extends PredictParameters { export enum ModelSourceType { - Local, - Azure, - RegisteredModels + Local = 'Local', + Azure = 'Azure', + RegisteredModels = 'RegisteredModels' +} + +export enum ModelActionType { + Import, + Predict } export interface ModelViewData { @@ -74,6 +79,7 @@ export abstract class ModelViewBase extends ViewBase { private _modelSourceType: ModelSourceType = ModelSourceType.Local; private _modelsViewData: ModelViewData[] = []; private _importTable: DatabaseTable | undefined; + private _modelActionType: ModelActionType = ModelActionType.Import; constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) { super(apiWrapper, root, parent); @@ -245,6 +251,28 @@ export abstract class ModelViewBase extends ViewBase { return await this.sendDataRequest(ListGroupsEventName, args); } + /** + * Sets model action type + */ + public set modelActionType(value: ModelActionType) { + if (this.parent) { + this.parent.modelActionType = value; + } else { + this._modelActionType = value; + } + } + + /** + * Returns model action type + */ + public get modelActionType(): ModelActionType { + if (this.parent) { + return this.parent.modelActionType; + } else { + return this._modelActionType; + } + } + /** * Sets model source type */ diff --git a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts index 10fc50a876..5ca9c4d778 100644 --- a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts +++ b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts @@ -4,7 +4,7 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { ModelViewBase, ModelSourceType } from '../modelViewBase'; +import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; import { ModelSourcesComponent } from '../modelSourcesComponent'; import { LocalModelsComponent } from '../localModelsComponent'; @@ -34,6 +34,7 @@ export class PredictWizard extends ModelViewBase { parent?: ModelViewBase) { super(apiWrapper, root); this._parentView = parent; + this.modelActionType = ModelActionType.Predict; } /** diff --git a/extensions/machine-learning/src/views/viewBase.ts b/extensions/machine-learning/src/views/viewBase.ts index 36bf18043f..a45cba484d 100644 --- a/extensions/machine-learning/src/views/viewBase.ts +++ b/extensions/machine-learning/src/views/viewBase.ts @@ -24,6 +24,7 @@ export const LocalPathsEventName = 'localPaths'; * Base class for views */ export abstract class ViewBase extends EventEmitterCollection { + protected _toDispose: vscode.Disposable[] = []; protected _mainViewPanel: azdata.window.Dialog | azdata.window.Wizard | undefined; public viewPanel: azdata.window.ModelViewPanel | undefined; public connection: azdata.connection.ConnectionProfile | undefined; @@ -197,4 +198,9 @@ export abstract class ViewBase extends EventEmitterCollection { } public abstract refresh(): Promise; + + public dispose(): void { + super.dispose(); + this._toDispose.forEach(disposable => disposable.dispose()); + } }