mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
ML - model source type page (#13077)
* Initial checkin * Style adjustments. * addressed PR comments Co-authored-by: Hale Rankin <harankin@microsoft.com>
This commit is contained in:
@@ -194,7 +194,12 @@ export const azureModels = localize('models.azureModels', "Models");
|
||||
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
|
||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?");
|
||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Select model source type");
|
||||
export const localModelSourceDescriptionForImport = localize('models.localModelSourceDescriptionForImport', "‘File Upload’ is selected. This allows you to import a model file from your local machine into a model database in this SQL instance. Click ‘Next’ to continue.");
|
||||
export const azureModelSourceDescriptionForImport = localize('models.azureModelSourceDescriptionForImport', "‘Azure Machine Learning’ is selected. This allows you to import models stored in Azure Machine Learning workspaces in a model database in this SQL instance. Click ‘Next’ to continue.");
|
||||
export const localModelSourceDescriptionForPredict = localize('models.localModelSourceDescriptionForPredict', "‘File Upload’ is selected. This allows you to upload a model file from your local machine. Click ‘Next’ to continue.");
|
||||
export const importedModelSourceDescriptionForPredict = localize('models.importedModelSourceDescriptionForPredict', "‘Imported Models’ is selected. This allows you to choose from models stored in a model table in your database. Click ‘Next’ to continue.");
|
||||
export const azureModelSourceDescriptionForPredict = localize('models.azureModelSourceDescriptionForPredict', "‘Azure Machine Learning’ is selected. This allows you to choose from models stored in Azure Machine Learning workspaces. Click ‘Next’ to continue.");
|
||||
export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Select or enter the location to import the models to");
|
||||
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map source data to model");
|
||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details");
|
||||
@@ -202,7 +207,7 @@ export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "S
|
||||
export const modelLocalSourceTooltip = localize('models.modelLocalSourceTooltip', "File paths of the models to import");
|
||||
export const onnxNotSupportedError = localize('models.onnxNotSupportedError', "ONNX runtime is not supported in current server");
|
||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
|
||||
export const importModelDoneButton = localize('models.importModelDoneButton', "Import");
|
||||
export const predictModel = localize('models.predictModel', "Predict");
|
||||
export const registerModelTitle = localize('models.RegisterWizard', "Import models");
|
||||
export const importedModelTitle = localize('models.importedModelTitle', "Imported models");
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
|
||||
import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { ModelSourcesComponent } from '../modelSourcesComponent';
|
||||
import { LocalModelsComponent } from '../localModelsComponent';
|
||||
@@ -34,6 +34,7 @@ export class ImportModelWizard extends ModelViewBase {
|
||||
parent?: ModelViewBase) {
|
||||
super(apiWrapper, root);
|
||||
this._parentView = parent;
|
||||
this.modelActionType = ModelActionType.Import;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -49,7 +50,7 @@ export class ImportModelWizard extends ModelViewBase {
|
||||
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage, this.modelImportTargetPage]);
|
||||
|
||||
this.mainViewPanel = wizard;
|
||||
wizard.doneButton.label = constants.azureRegisterModel;
|
||||
wizard.doneButton.label = constants.importModelDoneButton;
|
||||
wizard.generateScriptButton.hidden = true;
|
||||
wizard.displayPageTitles = true;
|
||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase';
|
||||
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType, ModelActionType } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
@@ -16,10 +16,12 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _flexContainer: azdata.FlexContainer | undefined;
|
||||
private _amlModel: azdata.CardComponent | undefined;
|
||||
private _localModel: azdata.CardComponent | undefined;
|
||||
private _registeredModels: azdata.CardComponent | undefined;
|
||||
private _amlModel: azdata.RadioCard | undefined;
|
||||
private _localModel: azdata.RadioCard | undefined;
|
||||
private _registeredModels: azdata.RadioCard | undefined;
|
||||
private _sourceType: ModelSourceType = ModelSourceType.Local;
|
||||
private _defaultSourceType = ModelSourceType.Local;
|
||||
private _selectedSourceLabel: azdata.TextComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
@@ -31,65 +33,44 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._sourceType = this._options && this._options.length > 0 ? this._options[0] : ModelSourceType.Local;
|
||||
this._sourceType = this._options && this._options.length > 0 ? this._options[0] : this._defaultSourceType;
|
||||
this.modelSourceType = this._sourceType;
|
||||
this._localModel = modelBuilder.card()
|
||||
.withProperties({
|
||||
value: 'local',
|
||||
name: 'modelLocation',
|
||||
label: constants.localModelSource,
|
||||
selected: this._sourceType === ModelSourceType.Local,
|
||||
cardType: azdata.CardType.VerticalButton,
|
||||
iconPath: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') },
|
||||
width: 50
|
||||
}).component();
|
||||
this._amlModel = modelBuilder.card()
|
||||
.withProperties({
|
||||
value: 'aml',
|
||||
name: 'modelLocation',
|
||||
label: constants.azureModelSource,
|
||||
selected: this._sourceType === ModelSourceType.Azure,
|
||||
cardType: azdata.CardType.VerticalButton,
|
||||
iconPath: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') },
|
||||
width: 50
|
||||
}).component();
|
||||
let selectedCardId: string = this.convertSourceIdToString(this._sourceType);
|
||||
|
||||
this._registeredModels = modelBuilder.card()
|
||||
.withProperties({
|
||||
value: 'registered',
|
||||
name: 'modelLocation',
|
||||
label: constants.registeredModelsSource,
|
||||
selected: this._sourceType === ModelSourceType.RegisteredModels,
|
||||
cardType: azdata.CardType.VerticalButton,
|
||||
iconPath: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') },
|
||||
width: 50
|
||||
}).component();
|
||||
this._localModel = {
|
||||
descriptions: [{
|
||||
textValue: constants.localModelSource,
|
||||
textStyles: {
|
||||
'font-size': '14px'
|
||||
}
|
||||
}],
|
||||
id: this.convertSourceIdToString(ModelSourceType.Local),
|
||||
icon: { light: this.asAbsolutePath('images/fileUpload.svg'), dark: this.asAbsolutePath('images/fileUpload.svg') }
|
||||
};
|
||||
this._amlModel = {
|
||||
descriptions: [{
|
||||
textValue: constants.azureModelSource,
|
||||
textStyles: {
|
||||
'font-size': '14px'
|
||||
}
|
||||
}],
|
||||
|
||||
this._localModel.onCardSelectedChanged(() => {
|
||||
this._sourceType = ModelSourceType.Local;
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
if (this._amlModel && this._registeredModels) {
|
||||
this._amlModel.selected = false;
|
||||
this._registeredModels.selected = false;
|
||||
}
|
||||
});
|
||||
this._amlModel.onCardSelectedChanged(() => {
|
||||
this._sourceType = ModelSourceType.Azure;
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
if (this._localModel && this._registeredModels) {
|
||||
this._localModel.selected = false;
|
||||
this._registeredModels.selected = false;
|
||||
}
|
||||
});
|
||||
this._registeredModels.onCardSelectedChanged(() => {
|
||||
this._sourceType = ModelSourceType.RegisteredModels;
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
if (this._localModel && this._amlModel) {
|
||||
this._localModel.selected = false;
|
||||
this._amlModel.selected = false;
|
||||
}
|
||||
});
|
||||
let components: azdata.Component[] = [];
|
||||
id: this.convertSourceIdToString(ModelSourceType.Azure),
|
||||
icon: { light: this.asAbsolutePath('images/aml.svg'), dark: this.asAbsolutePath('images/aml.svg') }
|
||||
};
|
||||
|
||||
this._registeredModels = {
|
||||
descriptions: [{
|
||||
textValue: constants.registeredModelsSource,
|
||||
textStyles: {
|
||||
'font-size': '14px'
|
||||
}
|
||||
}],
|
||||
id: this.convertSourceIdToString(ModelSourceType.RegisteredModels),
|
||||
icon: { light: this.asAbsolutePath('images/imported.svg'), dark: this.asAbsolutePath('images/imported.svg') }
|
||||
};
|
||||
|
||||
let components: azdata.RadioCard[] = [];
|
||||
|
||||
this._options.forEach(option => {
|
||||
switch (option) {
|
||||
@@ -110,29 +91,95 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
||||
break;
|
||||
}
|
||||
});
|
||||
this._flexContainer = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between'
|
||||
}).withItems(components).component();
|
||||
let radioCardGroup = modelBuilder.radioCardGroup()
|
||||
.withProperties({
|
||||
cards: components,
|
||||
iconHeight: '100px',
|
||||
iconWidth: '100px',
|
||||
cardWidth: '170px',
|
||||
cardHeight: '170px',
|
||||
ariaLabel: 'test',
|
||||
selectedCardId: selectedCardId
|
||||
}).component();
|
||||
this._flexContainer = modelBuilder.flexContainer().withLayout({
|
||||
flexFlow: 'column'
|
||||
}).withItems([radioCardGroup]).component();
|
||||
this._selectedSourceLabel = modelBuilder.text().withProperties({
|
||||
value: this.getSourceTypeDescription(this._sourceType),
|
||||
CSSStyles: {
|
||||
'font-size': '14px',
|
||||
'margin': '0',
|
||||
'width': '438px'
|
||||
}
|
||||
}).component();
|
||||
|
||||
this._toDispose.push(radioCardGroup.onSelectionChanged(({ cardId }) => {
|
||||
this._sourceType = this.convertSourceIdToEnum(cardId);
|
||||
if (this._selectedSourceLabel) {
|
||||
this._selectedSourceLabel.value = this.getSourceTypeDescription(this._sourceType);
|
||||
}
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
}));
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
component: this._flexContainer
|
||||
}, {
|
||||
title: '',
|
||||
component: this._selectedSourceLabel
|
||||
}]).component();
|
||||
|
||||
return this._form;
|
||||
}
|
||||
|
||||
private convertSourceIdToString(sourceId: ModelSourceType): string {
|
||||
return sourceId.toString();
|
||||
}
|
||||
|
||||
private convertSourceIdToEnum(sourceId: string): ModelSourceType {
|
||||
switch (sourceId) {
|
||||
case ModelSourceType.Local.toString():
|
||||
return ModelSourceType.Local;
|
||||
case ModelSourceType.Azure.toString():
|
||||
return ModelSourceType.Azure;
|
||||
case ModelSourceType.RegisteredModels.toString():
|
||||
return ModelSourceType.RegisteredModels;
|
||||
}
|
||||
return this._defaultSourceType;
|
||||
}
|
||||
|
||||
private getSourceTypeDescription(sourceId: ModelSourceType): string {
|
||||
if (this.modelActionType === ModelActionType.Import) {
|
||||
switch (sourceId) {
|
||||
case ModelSourceType.Local:
|
||||
return constants.localModelSourceDescriptionForImport;
|
||||
case ModelSourceType.Azure:
|
||||
return constants.azureModelSourceDescriptionForImport;
|
||||
}
|
||||
} else if (this.modelActionType === ModelActionType.Predict) {
|
||||
switch (sourceId) {
|
||||
case ModelSourceType.Local:
|
||||
return constants.localModelSourceDescriptionForPredict;
|
||||
case ModelSourceType.Azure:
|
||||
return constants.azureModelSourceDescriptionForPredict;
|
||||
case ModelSourceType.RegisteredModels:
|
||||
return constants.importedModelSourceDescriptionForPredict;
|
||||
}
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flexContainer) {
|
||||
if (this._flexContainer && this._selectedSourceLabel) {
|
||||
formBuilder.addFormItem({ title: '', component: this._flexContainer });
|
||||
formBuilder.addFormItem({ title: '', component: this._selectedSourceLabel });
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flexContainer) {
|
||||
if (this._flexContainer && this._selectedSourceLabel) {
|
||||
formBuilder.removeFormItem({ title: '', component: this._flexContainer });
|
||||
formBuilder.removeFormItem({ title: '', component: this._selectedSourceLabel });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,9 +28,14 @@ export interface PredictModelEventArgs extends PredictParameters {
|
||||
|
||||
|
||||
export enum ModelSourceType {
|
||||
Local,
|
||||
Azure,
|
||||
RegisteredModels
|
||||
Local = 'Local',
|
||||
Azure = 'Azure',
|
||||
RegisteredModels = 'RegisteredModels'
|
||||
}
|
||||
|
||||
export enum ModelActionType {
|
||||
Import,
|
||||
Predict
|
||||
}
|
||||
|
||||
export interface ModelViewData {
|
||||
@@ -74,6 +79,7 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
|
||||
private _modelsViewData: ModelViewData[] = [];
|
||||
private _importTable: DatabaseTable | undefined;
|
||||
private _modelActionType: ModelActionType = ModelActionType.Import;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
|
||||
super(apiWrapper, root, parent);
|
||||
@@ -245,6 +251,28 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
return await this.sendDataRequest(ListGroupsEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets model action type
|
||||
*/
|
||||
public set modelActionType(value: ModelActionType) {
|
||||
if (this.parent) {
|
||||
this.parent.modelActionType = value;
|
||||
} else {
|
||||
this._modelActionType = value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns model action type
|
||||
*/
|
||||
public get modelActionType(): ModelActionType {
|
||||
if (this.parent) {
|
||||
return this.parent.modelActionType;
|
||||
} else {
|
||||
return this._modelActionType;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets model source type
|
||||
*/
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
|
||||
import { ModelViewBase, ModelSourceType, ModelActionType } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { ModelSourcesComponent } from '../modelSourcesComponent';
|
||||
import { LocalModelsComponent } from '../localModelsComponent';
|
||||
@@ -34,6 +34,7 @@ export class PredictWizard extends ModelViewBase {
|
||||
parent?: ModelViewBase) {
|
||||
super(apiWrapper, root);
|
||||
this._parentView = parent;
|
||||
this.modelActionType = ModelActionType.Predict;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -24,6 +24,7 @@ export const LocalPathsEventName = 'localPaths';
|
||||
* Base class for views
|
||||
*/
|
||||
export abstract class ViewBase extends EventEmitterCollection {
|
||||
protected _toDispose: vscode.Disposable[] = [];
|
||||
protected _mainViewPanel: azdata.window.Dialog | azdata.window.Wizard | undefined;
|
||||
public viewPanel: azdata.window.ModelViewPanel | undefined;
|
||||
public connection: azdata.connection.ConnectionProfile | undefined;
|
||||
@@ -197,4 +198,9 @@ export abstract class ViewBase extends EventEmitterCollection {
|
||||
}
|
||||
|
||||
public abstract refresh(): Promise<void>;
|
||||
|
||||
public dispose(): void {
|
||||
super.dispose();
|
||||
this._toDispose.forEach(disposable => disposable.dispose());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user