ML - model source type page (#13077)

* Initial checkin

* Style adjustments.

* addressed PR comments

Co-authored-by: Hale Rankin <harankin@microsoft.com>
This commit is contained in:
Leila Lali
2020-10-28 11:57:59 -07:00
committed by GitHub
parent 429d8fe584
commit 5c474d8614
6 changed files with 163 additions and 75 deletions

View File

@@ -194,7 +194,12 @@ export const azureModels = localize('models.azureModels', "Models");
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?");
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Select model source type");
export const localModelSourceDescriptionForImport = localize('models.localModelSourceDescriptionForImport', "File Upload is selected. This allows you to import a model file from your local machine into a model database in this SQL instance. Click Next to continue.");
export const azureModelSourceDescriptionForImport = localize('models.azureModelSourceDescriptionForImport', "Azure Machine Learning is selected. This allows you to import models stored in Azure Machine Learning workspaces in a model database in this SQL instance. Click Next to continue.");
export const localModelSourceDescriptionForPredict = localize('models.localModelSourceDescriptionForPredict', "File Upload is selected. This allows you to upload a model file from your local machine. Click Next to continue.");
export const importedModelSourceDescriptionForPredict = localize('models.importedModelSourceDescriptionForPredict', "Imported Models is selected. This allows you to choose from models stored in a model table in your database. Click Next to continue.");
export const azureModelSourceDescriptionForPredict = localize('models.azureModelSourceDescriptionForPredict', "Azure Machine Learning is selected. This allows you to choose from models stored in Azure Machine Learning workspaces. Click Next to continue.");
export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Select or enter the location to import the models to");
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map source data to model");
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details");
@@ -202,7 +207,7 @@ export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "S
export const modelLocalSourceTooltip = localize('models.modelLocalSourceTooltip', "File paths of the models to import");
export const onnxNotSupportedError = localize('models.onnxNotSupportedError', "ONNX runtime is not supported in current server");
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
export const importModelDoneButton = localize('models.importModelDoneButton', "Import");
export const predictModel = localize('models.predictModel', "Predict");
export const registerModelTitle = localize('models.RegisterWizard', "Import models");
export const importedModelTitle = localize('models.importedModelTitle', "Imported models");

View File

@@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
@@ -34,6 +34,7 @@ export class ImportModelWizard extends ModelViewBase {
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
this.modelActionType = ModelActionType.Import;
}
/**
@@ -49,7 +50,7 @@ export class ImportModelWizard extends ModelViewBase {
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage, this.modelImportTargetPage]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.azureRegisterModel;
wizard.doneButton.label = constants.importModelDoneButton;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {

View File

@@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase';
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType, ModelActionType } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
@@ -16,10 +16,12 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
private _form: azdata.FormContainer | undefined;
private _flexContainer: azdata.FlexContainer | undefined;
private _amlModel: azdata.CardComponent | undefined;
private _localModel: azdata.CardComponent | undefined;
private _registeredModels: azdata.CardComponent | undefined;
private _amlModel: azdata.RadioCard | undefined;
private _localModel: azdata.RadioCard | undefined;
private _registeredModels: azdata.RadioCard | undefined;
private _sourceType: ModelSourceType = ModelSourceType.Local;
private _defaultSourceType = ModelSourceType.Local;
private _selectedSourceLabel: azdata.TextComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
@@ -31,65 +33,44 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._sourceType = this._options && this._options.length > 0 ? this._options[0] : ModelSourceType.Local;
this._sourceType = this._options && this._options.length > 0 ? this._options[0] : this._defaultSourceType;
this.modelSourceType = this._sourceType;
this._localModel = modelBuilder.card()
.withProperties({
value: 'local',
name: 'modelLocation',
label: constants.localModelSource,
selected: this._sourceType === ModelSourceType.Local,
cardType: azdata.CardType.VerticalButton,
iconPath: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') },
width: 50
}).component();
this._amlModel = modelBuilder.card()
.withProperties({
value: 'aml',
name: 'modelLocation',
label: constants.azureModelSource,
selected: this._sourceType === ModelSourceType.Azure,
cardType: azdata.CardType.VerticalButton,
iconPath: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') },
width: 50
}).component();
let selectedCardId: string = this.convertSourceIdToString(this._sourceType);
this._registeredModels = modelBuilder.card()
.withProperties({
value: 'registered',
name: 'modelLocation',
label: constants.registeredModelsSource,
selected: this._sourceType === ModelSourceType.RegisteredModels,
cardType: azdata.CardType.VerticalButton,
iconPath: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') },
width: 50
}).component();
this._localModel = {
descriptions: [{
textValue: constants.localModelSource,
textStyles: {
'font-size': '14px'
}
}],
id: this.convertSourceIdToString(ModelSourceType.Local),
icon: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') }
};
this._amlModel = {
descriptions: [{
textValue: constants.azureModelSource,
textStyles: {
'font-size': '14px'
}
}],
this._localModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Local;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._amlModel && this._registeredModels) {
this._amlModel.selected = false;
this._registeredModels.selected = false;
}
});
this._amlModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Azure;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._registeredModels) {
this._localModel.selected = false;
this._registeredModels.selected = false;
}
});
this._registeredModels.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.RegisteredModels;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._amlModel) {
this._localModel.selected = false;
this._amlModel.selected = false;
}
});
let components: azdata.Component[] = [];
id: this.convertSourceIdToString(ModelSourceType.Azure),
icon: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') }
};
this._registeredModels = {
descriptions: [{
textValue: constants.registeredModelsSource,
textStyles: {
'font-size': '14px'
}
}],
id: this.convertSourceIdToString(ModelSourceType.RegisteredModels),
icon: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') }
};
let components: azdata.RadioCard[] = [];
this._options.forEach(option => {
switch (option) {
@@ -110,29 +91,95 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
break;
}
});
this._flexContainer = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
}).withItems(components).component();
let radioCardGroup = modelBuilder.radioCardGroup()
.withProperties({
cards: components,
iconHeight: '100px',
iconWidth: '100px',
cardWidth: '170px',
cardHeight: '170px',
ariaLabel: 'test',
selectedCardId: selectedCardId
}).component();
this._flexContainer = modelBuilder.flexContainer().withLayout({
flexFlow: 'column'
}).withItems([radioCardGroup]).component();
this._selectedSourceLabel = modelBuilder.text().withProperties({
value: this.getSourceTypeDescription(this._sourceType),
CSSStyles: {
'font-size': '14px',
'margin': '0',
'width': '438px'
}
}).component();
this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => {
this._sourceType = this.convertSourceIdToEnum(cardId);
if (this._selectedSourceLabel) {
this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
}
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
}));
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._flexContainer
}, {
title: '',
component: this._selectedSourceLabel
}]).component();
return this._form;
}
private convertSourceIdToString(sourceId: ModelSourceType): string {
return sourceId.toString();
}
private convertSourceIdToEnum(sourceId: string): ModelSourceType {
switch (sourceId) {
case ModelSourceType.Local.toString():
return ModelSourceType.Local;
case ModelSourceType.Azure.toString():
return ModelSourceType.Azure;
case ModelSourceType.RegisteredModels.toString():
return ModelSourceType.RegisteredModels;
}
return this._defaultSourceType;
}
private getSourceTypeDescription(sourceId: ModelSourceType): string {
if (this.modelActionType === ModelActionType.Import) {
switch (sourceId) {
case ModelSourceType.Local:
return constants.localModelSourceDescriptionForImport;
case ModelSourceType.Azure:
return constants.azureModelSourceDescriptionForImport;
}
} else if (this.modelActionType === ModelActionType.Predict) {
switch (sourceId) {
case ModelSourceType.Local:
return constants.localModelSourceDescriptionForPredict;
case ModelSourceType.Azure:
return constants.azureModelSourceDescriptionForPredict;
case ModelSourceType.RegisteredModels:
return constants.importedModelSourceDescriptionForPredict;
}
}
return '';
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
if (this._flexContainer && this._selectedSourceLabel) {
formBuilder.addFormItem({ title: '', component: this._flexContainer });
formBuilder.addFormItem({ title: '', component: this._selectedSourceLabel });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
if (this._flexContainer && this._selectedSourceLabel) {
formBuilder.removeFormItem({ title: '', component: this._flexContainer });
formBuilder.removeFormItem({ title: '', component: this._selectedSourceLabel });
}
}

View File

@@ -28,9 +28,14 @@ export interface PredictModelEventArgs extends PredictParameters {
export enum ModelSourceType {
Local,
Azure,
RegisteredModels
Local = 'Local',
Azure = 'Azure',
RegisteredModels = 'RegisteredModels'
}
export enum ModelActionType {
Import,
Predict
}
export interface ModelViewData {
@@ -74,6 +79,7 @@ export abstract class ModelViewBase extends ViewBase {
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
private _modelsViewData: ModelViewData[] = [];
private _importTable: DatabaseTable | undefined;
private _modelActionType: ModelActionType = ModelActionType.Import;
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
super(apiWrapper, root, parent);
@@ -245,6 +251,28 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(ListGroupsEventName, args);
}
/**
* Sets model action type
*/
public set modelActionType(value: ModelActionType) {
if (this.parent) {
this.parent.modelActionType = value;
} else {
this._modelActionType = value;
}
}
/**
* Returns model action type
*/
public get modelActionType(): ModelActionType {
if (this.parent) {
return this.parent.modelActionType;
} else {
return this._modelActionType;
}
}
/**
* Sets model source type
*/

View File

@@ -4,7 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
@@ -34,6 +34,7 @@ export class PredictWizard extends ModelViewBase {
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
this.modelActionType = ModelActionType.Predict;
}
/**

View File

@@ -24,6 +24,7 @@ export const LocalPathsEventName = 'localPaths';
* Base class for views
*/
export abstract class ViewBase extends EventEmitterCollection {
protected _toDispose: vscode.Disposable[] = [];
protected _mainViewPanel: azdata.window.Dialog | azdata.window.Wizard | undefined;
public viewPanel: azdata.window.ModelViewPanel | undefined;
public connection: azdata.connection.ConnectionProfile | undefined;
@@ -197,4 +198,9 @@ export abstract class ViewBase extends EventEmitterCollection {
}
public abstract refresh(): Promise<void>;
public dispose(): void {
super.dispose();
this._toDispose.forEach(disposable => disposable.dispose());
}
}