ML - model source type page (#13077)

* Initial checkin

* Style adjustments.

* addressed PR comments

Co-authored-by: Hale Rankin <harankin@microsoft.com>
This commit is contained in:
Leila Lali
2020-10-28 11:57:59 -07:00
committed by GitHub
parent 429d8fe584
commit 5c474d8614
6 changed files with 163 additions and 75 deletions

View File

@@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
@@ -34,6 +34,7 @@ export class ImportModelWizard extends ModelViewBase {
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
this.modelActionType = ModelActionType.Import;
}
/**
@@ -49,7 +50,7 @@ export class ImportModelWizard extends ModelViewBase {
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage, this.modelImportTargetPage]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.azureRegisterModel;
wizard.doneButton.label = constants.importModelDoneButton;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {

View File

@@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase';
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType, ModelActionType } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
@@ -16,10 +16,12 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
private _form: azdata.FormContainer | undefined;
private _flexContainer: azdata.FlexContainer | undefined;
private _amlModel: azdata.CardComponent | undefined;
private _localModel: azdata.CardComponent | undefined;
private _registeredModels: azdata.CardComponent | undefined;
private _amlModel: azdata.RadioCard | undefined;
private _localModel: azdata.RadioCard | undefined;
private _registeredModels: azdata.RadioCard | undefined;
private _sourceType: ModelSourceType = ModelSourceType.Local;
private _defaultSourceType = ModelSourceType.Local;
private _selectedSourceLabel: azdata.TextComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
@@ -31,65 +33,44 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._sourceType = this._options && this._options.length > 0 ? this._options[0] : ModelSourceType.Local;
this._sourceType = this._options && this._options.length > 0 ? this._options[0] : this._defaultSourceType;
this.modelSourceType = this._sourceType;
this._localModel = modelBuilder.card()
.withProperties({
value: 'local',
name: 'modelLocation',
label: constants.localModelSource,
selected: this._sourceType === ModelSourceType.Local,
cardType: azdata.CardType.VerticalButton,
iconPath: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') },
width: 50
}).component();
this._amlModel = modelBuilder.card()
.withProperties({
value: 'aml',
name: 'modelLocation',
label: constants.azureModelSource,
selected: this._sourceType === ModelSourceType.Azure,
cardType: azdata.CardType.VerticalButton,
iconPath: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') },
width: 50
}).component();
let selectedCardId: string = this.convertSourceIdToString(this._sourceType);
this._registeredModels = modelBuilder.card()
.withProperties({
value: 'registered',
name: 'modelLocation',
label: constants.registeredModelsSource,
selected: this._sourceType === ModelSourceType.RegisteredModels,
cardType: azdata.CardType.VerticalButton,
iconPath: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') },
width: 50
}).component();
this._localModel = {
descriptions: [{
textValue: constants.localModelSource,
textStyles: {
'font-size': '14px'
}
}],
id: this.convertSourceIdToString(ModelSourceType.Local),
icon: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') }
};
this._amlModel = {
descriptions: [{
textValue: constants.azureModelSource,
textStyles: {
'font-size': '14px'
}
}],
this._localModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Local;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._amlModel && this._registeredModels) {
this._amlModel.selected = false;
this._registeredModels.selected = false;
}
});
this._amlModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Azure;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._registeredModels) {
this._localModel.selected = false;
this._registeredModels.selected = false;
}
});
this._registeredModels.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.RegisteredModels;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._amlModel) {
this._localModel.selected = false;
this._amlModel.selected = false;
}
});
let components: azdata.Component[] = [];
id: this.convertSourceIdToString(ModelSourceType.Azure),
icon: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') }
};
this._registeredModels = {
descriptions: [{
textValue: constants.registeredModelsSource,
textStyles: {
'font-size': '14px'
}
}],
id: this.convertSourceIdToString(ModelSourceType.RegisteredModels),
icon: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') }
};
let components: azdata.RadioCard[] = [];
this._options.forEach(option => {
switch (option) {
@@ -110,29 +91,95 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
break;
}
});
this._flexContainer = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
}).withItems(components).component();
let radioCardGroup = modelBuilder.radioCardGroup()
.withProperties({
cards: components,
iconHeight: '100px',
iconWidth: '100px',
cardWidth: '170px',
cardHeight: '170px',
ariaLabel: 'test',
selectedCardId: selectedCardId
}).component();
this._flexContainer = modelBuilder.flexContainer().withLayout({
flexFlow: 'column'
}).withItems([radioCardGroup]).component();
this._selectedSourceLabel = modelBuilder.text().withProperties({
value: this.getSourceTypeDescription(this._sourceType),
CSSStyles: {
'font-size': '14px',
'margin': '0',
'width': '438px'
}
}).component();
this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => {
this._sourceType = this.convertSourceIdToEnum(cardId);
if (this._selectedSourceLabel) {
this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
}
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
}));
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._flexContainer
}, {
title: '',
component: this._selectedSourceLabel
}]).component();
return this._form;
}
private convertSourceIdToString(sourceId: ModelSourceType): string {
return sourceId.toString();
}
private convertSourceIdToEnum(sourceId: string): ModelSourceType {
switch (sourceId) {
case ModelSourceType.Local.toString():
return ModelSourceType.Local;
case ModelSourceType.Azure.toString():
return ModelSourceType.Azure;
case ModelSourceType.RegisteredModels.toString():
return ModelSourceType.RegisteredModels;
}
return this._defaultSourceType;
}
private getSourceTypeDescription(sourceId: ModelSourceType): string {
if (this.modelActionType === ModelActionType.Import) {
switch (sourceId) {
case ModelSourceType.Local:
return constants.localModelSourceDescriptionForImport;
case ModelSourceType.Azure:
return constants.azureModelSourceDescriptionForImport;
}
} else if (this.modelActionType === ModelActionType.Predict) {
switch (sourceId) {
case ModelSourceType.Local:
return constants.localModelSourceDescriptionForPredict;
case ModelSourceType.Azure:
return constants.azureModelSourceDescriptionForPredict;
case ModelSourceType.RegisteredModels:
return constants.importedModelSourceDescriptionForPredict;
}
}
return '';
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
if (this._flexContainer && this._selectedSourceLabel) {
formBuilder.addFormItem({ title: '', component: this._flexContainer });
formBuilder.addFormItem({ title: '', component: this._selectedSourceLabel });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
if (this._flexContainer && this._selectedSourceLabel) {
formBuilder.removeFormItem({ title: '', component: this._flexContainer });
formBuilder.removeFormItem({ title: '', component: this._selectedSourceLabel });
}
}

View File

@@ -28,9 +28,14 @@ export interface PredictModelEventArgs extends PredictParameters {
export enum ModelSourceType {
Local,
Azure,
RegisteredModels
Local = 'Local',
Azure = 'Azure',
RegisteredModels = 'RegisteredModels'
}
export enum ModelActionType {
Import,
Predict
}
export interface ModelViewData {
@@ -74,6 +79,7 @@ export abstract class ModelViewBase extends ViewBase {
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
private _modelsViewData: ModelViewData[] = [];
private _importTable: DatabaseTable | undefined;
private _modelActionType: ModelActionType = ModelActionType.Import;
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
super(apiWrapper, root, parent);
@@ -245,6 +251,28 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(ListGroupsEventName, args);
}
/**
* Sets model action type
*/
public set modelActionType(value: ModelActionType) {
if (this.parent) {
this.parent.modelActionType = value;
} else {
this._modelActionType = value;
}
}
/**
* Returns model action type
*/
public get modelActionType(): ModelActionType {
if (this.parent) {
return this.parent.modelActionType;
} else {
return this._modelActionType;
}
}
/**
* Sets model source type
*/

View File

@@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
@@ -34,6 +34,7 @@ export class PredictWizard extends ModelViewBase {
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
this.modelActionType = ModelActionType.Predict;
}
/**