ML - dashboard icons and links (#10153)

* ML - dashboard icons and links
This commit is contained in:
Leila Lali
2020-04-28 21:21:30 -07:00
committed by GitHub
parent 046995f2a5
commit 04af41c424
145 changed files with 387 additions and 134 deletions

View File

@@ -0,0 +1,170 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
import { AzureModelsTable } from './azureModelsTable';
import { IDataComponent, AzureModelResource } from '../interfaces';
import { ModelArtifact } from './prediction/modelArtifact';
import { AzureSignInComponent } from './azureSignInComponent';
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource[]> {
public azureModelsTable: AzureModelsTable | undefined;
public azureFilterComponent: AzureResourceFilterComponent | undefined;
public azureSignInComponent: AzureSignInComponent | undefined;
private _loader: azdata.LoadingComponent | undefined;
private _form: azdata.FormContainer | undefined;
private _downloadedFile: ModelArtifact | undefined;
/**
* Component to render a view to pick an azure model
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this);
this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect);
this.azureSignInComponent = new AzureSignInComponent(this._apiWrapper, modelBuilder, this);
this._loader = modelBuilder.loadingComponent()
.withItem(this.azureModelsTable.component)
.withProperties({
loading: true
}).component();
this.azureModelsTable.onModelSelectionChanged(async () => {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
this._downloadedFile = undefined;
});
this.azureFilterComponent.onWorkspacesSelectedChanged(async () => {
await this.onLoading();
await this.azureModelsTable?.loadData(this.azureFilterComponent?.data);
await this.onLoaded();
});
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this.azureFilterComponent.component
}, {
title: '',
component: this._loader
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
this.removeComponents(formBuilder);
if (this.azureFilterComponent?.data?.account) {
this.addAzureComponents(formBuilder);
} else {
this.addAzureSignInComponents(formBuilder);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
this.removeAzureComponents(formBuilder);
this.removeAzureSignInComponents(formBuilder);
}
private addAzureComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.addComponents(formBuilder);
formBuilder.addFormItems([{
title: '',
component: this._loader
}]);
}
}
private removeAzureComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.removeComponents(formBuilder);
formBuilder.removeFormItem({
title: '',
component: this._loader
});
}
}
private addAzureSignInComponents(formBuilder: azdata.FormBuilder) {
if (this.azureSignInComponent) {
this.azureSignInComponent.addComponents(formBuilder);
}
}
private removeAzureSignInComponents(formBuilder: azdata.FormBuilder) {
if (this.azureSignInComponent) {
this.azureSignInComponent.removeComponents(formBuilder);
}
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
private async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Loads the data in the components
*/
public async loadData(): Promise<void> {
await this.azureFilterComponent?.loadData();
}
/**
* Returns selected data
*/
public get data(): AzureModelResource[] | undefined {
return this.azureModelsTable?.data ? this.azureModelsTable?.data.map(x => Object.assign({}, this.azureFilterComponent?.data, {
model: x
})) : undefined;
}
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
const data = this.data;
if (!this._downloadedFile && data && data.length > 0) {
this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(data[0]));
}
return this._downloadedFile;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,184 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../common/constants';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { WorkspaceModel } from '../../modelManagement/interfaces';
import { IDataComponent, AzureWorkspaceResource } from '../interfaces';
/**
* View to render azure models in a table
*/
export class AzureModelsTable extends ModelViewBase implements IDataComponent<WorkspaceModel[]> {
private _table: azdata.DeclarativeTableComponent;
private _selectedModel: WorkspaceModel[] = [];
private _models: WorkspaceModel[] | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
/**
* Creates a view to render azure models in a table
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
this._table = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Version
displayName: constants.modelVersion,
ariaLabel: constants.modelVersion,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
}
public get component(): azdata.DeclarativeTableComponent {
return this._table;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadData(workspaceResource?: AzureWorkspaceResource | undefined): Promise<void> {
if (this._table && workspaceResource) {
this._models = await this.listAzureModels(workspaceResource);
let tableData: any[][] = [];
if (this._models) {
tableData = tableData.concat(this._models.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
this._onModelSelectionChanged.fire();
}
private createTableRow(model: WorkspaceModel): any[] {
if (this._modelBuilder) {
let selectModelButton: azdata.Component;
let onSelectItem = (checked: boolean) => {
const foundItem = this._selectedModel.find(x => x === model);
if (checked && !foundItem) {
this._selectedModel.push(model);
} else if (foundItem) {
this._selectedModel = this._selectedModel.filter(x => x !== model);
}
this._onModelSelectionChanged.fire();
};
if (this._multiSelect) {
const checkbox = this._modelBuilder.checkBox().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
checkbox.onChanged(() => {
onSelectItem(checkbox.checked || false);
});
selectModelButton = checkbox;
} else {
const radioButton = this._modelBuilder.radioButton().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
radioButton.onDidClick(() => {
onSelectItem(radioButton.checked || false);
});
selectModelButton = radioButton;
}
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
}
return [];
}
/**
* Returns selected data
*/
public get data(): WorkspaceModel[] | undefined {
if (this._models && this._selectedModel) {
return this._selectedModel;
}
return undefined;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,207 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { azureResource } from '../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import * as constants from '../../common/constants';
import { AzureWorkspaceResource, IDataComponent } from '../interfaces';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 300;
export class AzureResourceFilterComponent extends ModelViewBase implements IDataComponent<AzureWorkspaceResource> {
private _form: azdata.FormContainer;
private _accounts: azdata.DropDownComponent;
private _subscriptions: azdata.DropDownComponent;
private _groups: azdata.DropDownComponent;
private _workspaces: azdata.DropDownComponent;
private _azureAccounts: azdata.Account[] = [];
private _azureSubscriptions: azureResource.AzureResourceSubscription[] = [];
private _azureGroups: azureResource.AzureResource[] = [];
private _azureWorkspaces: Workspace[] = [];
private _onWorkspacesSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onWorkspacesSelectedChanged: vscode.Event<void> = this._onWorkspacesSelectedChanged.event;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._accounts = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._subscriptions = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._groups = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._workspaces = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._accounts.onValueChanged(async () => {
await this.onAccountSelected();
});
this._subscriptions.onValueChanged(async () => {
await this.onSubscriptionSelected();
});
this._groups.onValueChanged(async () => {
await this.onGroupSelected();
});
this._workspaces.onValueChanged(async () => {
await this.onWorkspaceSelectedChanged();
});
this._form = this._modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._accounts
}, {
title: constants.azureSubscription,
component: this._subscriptions
}, {
title: constants.azureGroup,
component: this._groups
}, {
title: constants.azureModelWorkspace,
component: this._workspaces
}]).component();
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
formBuilder.addFormItems([{
title: constants.azureAccount,
component: this._accounts
}, {
title: constants.azureSubscription,
component: this._subscriptions
}, {
title: constants.azureGroup,
component: this._groups
}, {
title: constants.azureModelWorkspace,
component: this._workspaces
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._accounts
});
formBuilder.removeFormItem({
title: constants.azureSubscription,
component: this._subscriptions
});
formBuilder.removeFormItem({
title: constants.azureGroup,
component: this._groups
});
formBuilder.removeFormItem({
title: constants.azureModelWorkspace,
component: this._workspaces
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component {
return this._form;
}
/**
* Returns selected data
*/
public get data(): AzureWorkspaceResource | undefined {
return {
account: this.account,
subscription: this.subscription,
group: this.group,
workspace: this.workspace
};
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
this._azureAccounts = await this.listAzureAccounts();
if (this._azureAccounts && this._azureAccounts.length > 0) {
let values = this._azureAccounts.map(a => { return { displayName: a.displayInfo.displayName, name: a.key.accountId }; });
this._accounts.values = values;
this._accounts.value = values[0];
}
await this.onAccountSelected();
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
private async onAccountSelected(): Promise<void> {
this._azureSubscriptions = await this.listAzureSubscriptions(this.account);
if (this._azureSubscriptions && this._azureSubscriptions.length > 0) {
let values = this._azureSubscriptions.map(s => { return { displayName: s.name, name: s.id }; });
this._subscriptions.values = values;
this._subscriptions.value = values[0];
}
await this.onSubscriptionSelected();
}
private async onSubscriptionSelected(): Promise<void> {
this._azureGroups = await this.listAzureGroups(this.account, this.subscription);
if (this._azureGroups && this._azureGroups.length > 0) {
let values = this._azureGroups.map(s => { return { displayName: s.name, name: s.id }; });
this._groups.values = values;
this._groups.value = values[0];
}
await this.onGroupSelected();
}
private async onGroupSelected(): Promise<void> {
this._azureWorkspaces = await this.listWorkspaces(this.account, this.subscription, this.group);
if (this._azureWorkspaces && this._azureWorkspaces.length > 0) {
let values = this._azureWorkspaces.map(s => { return { displayName: s.name || '', name: s.id || '' }; });
this._workspaces.values = values;
this._workspaces.value = values[0];
}
this.onWorkspaceSelectedChanged();
}
private onWorkspaceSelectedChanged(): void {
this._onWorkspacesSelectedChanged.fire();
}
private get workspace(): Workspace | undefined {
return this._azureWorkspaces && this._workspaces.value ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
}
private get account(): azdata.Account | undefined {
return this._azureAccounts && this._accounts.value ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
}
private get group(): azureResource.AzureResource | undefined {
return this._azureGroups && this._groups.value ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
}
private get subscription(): azureResource.AzureResourceSubscription | undefined {
return this._azureSubscriptions && this._subscriptions.value ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
}
}

View File

@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SignInToAzureEventName } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 300;
export class AzureSignInComponent extends ModelViewBase {
private _form: azdata.FormContainer;
private _signInButton: azdata.ButtonComponent;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._signInButton = this._modelBuilder.button().withProperties({
width: componentWidth,
label: constants.azureSignIn,
}).component();
this._signInButton.onDidClick(() => {
this.sendRequest(SignInToAzureEventName);
});
this._form = this._modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._signInButton
}]).component();
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._signInButton) {
formBuilder.addFormItems([{
title: constants.azureAccount,
component: this._signInButton
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._signInButton) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._signInButton
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component {
return this._form;
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,128 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
/**
* View to pick local models file
*/
export class LocalModelsComponent extends ModelViewBase implements IDataComponent<string[]> {
private _form: azdata.FormContainer | undefined;
private _flex: azdata.FlexContainer | undefined;
private _localPath: azdata.InputBoxComponent | undefined;
private _localBrowse: azdata.ButtonComponent | undefined;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._localPath = modelBuilder.inputBox().withProperties({
value: '',
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength
}).component();
this._localBrowse = modelBuilder.button().withProperties({
label: constants.browseModels,
width: this.browseButtonMaxLength,
CSSStyles: {
'text-align': 'end'
}
}).component();
this._localBrowse.onDidClick(async () => {
let options: vscode.OpenDialogOptions = {
canSelectFiles: true,
canSelectFolders: false,
canSelectMany: this._multiSelect,
filters: { 'ONNX File': ['onnx'] }
};
const filePaths = await this.getLocalPaths(options);
if (this._localPath && filePaths && filePaths.length > 0) {
this._localPath.value = this._multiSelect ? filePaths.join(';') : filePaths[0];
} else if (this._localPath) {
this._localPath.value = '';
}
});
this._flex = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between',
width: this.componentMaxLength
}).withItems([
this._localPath, this._localBrowse]
).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._flex
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.addFormItem({
title: '',
component: this._flex
});
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.removeFormItem({
title: '',
component: this._flex
});
}
}
/**
* Returns selected data
*/
public get data(): string[] {
if (this._localPath?.value) {
return this._localPath?.value.split(';');
} else {
return [];
}
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
/**
* Returns the page title
*/
public get title(): string {
return constants.localModelsTitle;
}
}

View File

@@ -0,0 +1,148 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { CurrentModelsTable } from './currentModelsTable';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IPageView, IComponentSettings } from '../../interfaces';
import { TableSelectionComponent } from '../tableSelectionComponent';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* View to render current registered models
*/
export class CurrentModelsComponent extends ModelViewBase implements IPageView {
private _tableComponent: azdata.Component | undefined;
private _dataTable: CurrentModelsTable | undefined;
private _loader: azdata.LoadingComponent | undefined;
private _tableSelectionComponent: TableSelectionComponent | undefined;
/**
*
* @param apiWrapper Creates new view
* @param parent page parent
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
this._tableSelectionComponent.registerComponent(modelBuilder);
this._tableSelectionComponent.onSelectedChanged(async () => {
await this.onTableSelected();
});
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings);
this._dataTable.registerComponent(modelBuilder);
this._tableComponent = this._dataTable.component;
let formModelBuilder = modelBuilder.formContainer();
this._tableSelectionComponent.addComponents(formModelBuilder);
if (this._tableComponent) {
formModelBuilder.addFormItem({
component: this._tableComponent,
title: ''
});
}
this._loader = modelBuilder.loadingComponent()
.withItem(formModelBuilder.component())
.withProperties({
loading: true
}).component();
return this._loader;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._tableSelectionComponent && this._dataTable) {
this._tableSelectionComponent.addComponents(formBuilder);
this._dataTable.addComponents(formBuilder);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._tableSelectionComponent && this._dataTable) {
this._tableSelectionComponent.removeComponents(formBuilder);
this._dataTable.removeComponents(formBuilder);
}
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._loader;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.onLoading();
try {
if (this._tableSelectionComponent) {
this._tableSelectionComponent.refresh();
}
await this._dataTable?.refresh();
} catch (err) {
this.showErrorMessage(constants.getErrorMessage(err));
} finally {
await this.onLoaded();
}
}
public get data(): ImportedModel[] | undefined {
return this._dataTable?.data;
}
private async onTableSelected(): Promise<void> {
if (this._tableSelectionComponent?.data) {
this.importTable = this._tableSelectionComponent?.data;
await this.storeImportConfigTable();
await this._dataTable?.refresh();
}
}
public get modelTable(): CurrentModelsTable | undefined {
return this._dataTable;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._dataTable) {
await this._dataTable.disposeComponent();
}
}
/**
* returns the title of the page
*/
public get title(): string {
return constants.currentModelsTitle;
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
private async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
}

View File

@@ -0,0 +1,314 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../../common/constants';
import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { IDataComponent, IComponentSettings } from '../../interfaces';
import { ModelArtifact } from '../prediction/modelArtifact';
import * as utils from '../../../common/utils';
/**
* View to render registered models table
*/
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<ImportedModel[]> {
private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined;
private _selectedModel: ImportedModel[] = [];
private _loader: azdata.LoadingComponent | undefined;
private _downloadedFile: ModelArtifact | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._modelBuilder = modelBuilder;
let columns = [
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
if (this._settings.editable) {
columns.push(
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
);
}
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: columns,
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
this._loader = modelBuilder.loadingComponent()
.withItem(this._table)
.withProperties({
loading: true
}).component();
return this._loader;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this.component) {
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this.component) {
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component });
}
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._loader;
}
/**
* Loads the data in the component
*/
public async loadData(): Promise<void> {
await this.onLoading();
if (this._table) {
let models: ImportedModel[] | undefined;
if (this.importTable) {
models = await this.listModels(this.importTable);
} else {
this.showErrorMessage('No import table');
}
let tableData: any[][] = [];
if (models) {
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
this.onModelSelected();
await this.onLoaded();
}
public async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
public async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
private createTableRow(model: ImportedModel): any[] {
let row: any[] = [model.modelName, model.created];
if (this._modelBuilder) {
const selectButton = this.createSelectButton(model);
if (selectButton) {
row.push(selectButton);
}
const editButtons = this.createEditButtons(model);
if (editButtons && editButtons.length > 0) {
row = row.concat(editButtons);
}
}
return row;
}
private createSelectButton(model: ImportedModel): azdata.Component | undefined {
let selectModelButton: azdata.Component | undefined = undefined;
if (this._modelBuilder && this._settings.selectable) {
let onSelectItem = (checked: boolean) => {
if (!this._settings.multiSelect) {
this._selectedModel = [];
}
const foundItem = this._selectedModel.find(x => x === model);
if (checked && !foundItem) {
this._selectedModel.push(model);
} else if (foundItem) {
this._selectedModel = this._selectedModel.filter(x => x !== model);
}
this.onModelSelected();
};
if (this._settings.multiSelect) {
const checkbox = this._modelBuilder.checkBox().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
checkbox.onChanged(() => {
onSelectItem(checkbox.checked || false);
});
selectModelButton = checkbox;
} else {
const radioButton = this._modelBuilder.radioButton().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
radioButton.onDidClick(() => {
onSelectItem(radioButton.checked || false);
});
selectModelButton = radioButton;
}
}
return selectModelButton;
}
private createEditButtons(model: ImportedModel): azdata.Component[] | undefined {
let dropButton: azdata.ButtonComponent | undefined = undefined;
let editButton: azdata.ButtonComponent | undefined = undefined;
if (this._modelBuilder && this._settings.editable) {
dropButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
width: 15,
height: 15
}).component();
dropButton.onDidClick(async () => {
try {
const confirm = await utils.promptConfirm(constants.confirmDeleteModel(model.modelName), this._apiWrapper);
if (confirm) {
await this.sendDataRequest(DeleteModelEventName, model);
if (this.parent) {
await this.parent?.refresh();
}
}
} catch (error) {
this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`);
}
});
editButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
width: 15,
height: 15
}).component();
editButton.onDidClick(async () => {
await this.sendDataRequest(EditModelEventName, model);
});
}
return editButton && dropButton ? [editButton, dropButton] : undefined;
}
private async onModelSelected(): Promise<void> {
this._onModelSelectionChanged.fire();
if (this._downloadedFile) {
await this._downloadedFile.close();
}
this._downloadedFile = undefined;
}
/**
* Returns selected data
*/
public get data(): ImportedModel[] | undefined {
return this._selectedModel;
}
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
if (!this._downloadedFile && this.data && this.data.length > 0) {
this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data[0]));
}
return this._downloadedFile;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,75 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ModelViewBase, UpdateModelEventName } from '../modelViewBase';
import * as constants from '../../../common/constants';
import { ApiWrapper } from '../../../common/apiWrapper';
import { DialogView } from '../../dialogView';
import { ModelDetailsEditPage } from './modelDetailsEditPage';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* Dialog to render registered model views
*/
export class EditModelDialog extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string,
private _parentView: ModelViewBase | undefined,
private _model: ImportedModel) {
super(apiWrapper, root);
this.dialogView = new DialogView(this._apiWrapper);
}
public dialogView: DialogView;
public editModelPage: ModelDetailsEditPage | undefined;
/**
* Opens a dialog to edit models.
*/
public open(): void {
this.editModelPage = new ModelDetailsEditPage(this._apiWrapper, this, this._model);
let registerModelButton = this._apiWrapper.createButton(constants.extLangSaveButtonText);
registerModelButton.onClick(async () => {
if (this.editModelPage) {
const valid = await this.editModelPage.validate();
if (valid) {
try {
await this.sendDataRequest(UpdateModelEventName, this.editModelPage?.data);
this.showInfoMessage(constants.modelUpdatedSuccessfully);
if (this._parentView) {
await this._parentView.refresh();
}
} catch (error) {
this.showInfoMessage(`${constants.modelUpdateFailedError} ${constants.getErrorMessage(error)}`);
}
}
}
});
let dialog = this.dialogView.createDialog(constants.editModelTitle, [this.editModelPage]);
dialog.customButtons = [registerModelButton];
this.mainViewPanel = dialog;
dialog.okButton.hidden = true;
dialog.cancelButton.label = constants.extLangDoneButtonText;
dialog.registerCloseValidator(() => {
return false; // Blocks Enter key from closing dialog.
});
this._apiWrapper.openDialog(dialog);
}
/**
* Resets the tabs for given provider Id
*/
public async refresh(): Promise<void> {
if (this.dialogView) {
this.dialogView.refresh();
}
}
}

View File

@@ -0,0 +1,113 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
import { AzureModelsComponent } from '../azureModelsComponent';
import * as constants from '../../../common/constants';
import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage';
import { ModelDetailsPage } from '../modelDetailsPage';
import { ModelBrowsePage } from '../modelBrowsePage';
import { ModelImportLocationPage } from './modelImportLocationPage';
/**
* Wizard to register a model
*/
export class ImportModelWizard extends ModelViewBase {
public modelSourcePage: ModelSourcePage | undefined;
public modelBrowsePage: ModelBrowsePage | undefined;
public modelDetailsPage: ModelDetailsPage | undefined;
public modelImportTargetPage: ModelImportLocationPage | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
constructor(
apiWrapper: ApiWrapper,
root: string,
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
}
/**
* Opens a dialog to manage packages used by notebooks.
*/
public async open(): Promise<void> {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this);
this.modelImportTargetPage = new ModelImportLocationPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelImportTargetPage, this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.azureRegisterModel;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
let validated: boolean = true;
if (pageInfo.newPage > pageInfo.lastPage) {
validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
}
if (validated && pageInfo.newPage === undefined) {
wizard.cancelButton.enabled = false;
wizard.backButton.enabled = false;
let result = await this.registerModel();
wizard.cancelButton.enabled = true;
wizard.backButton.enabled = true;
if (this._parentView) {
this._parentView.importTable = this.importTable;
await this._parentView.refresh();
}
return result;
}
return validated;
});
await wizard.open();
}
public get modelResources(): ModelSourcesComponent | undefined {
return this.modelSourcePage?.modelResources;
}
public get localModelsComponent(): LocalModelsComponent | undefined {
return this.modelBrowsePage?.localModelsComponent;
}
public get azureModelsComponent(): AzureModelsComponent | undefined {
return this.modelBrowsePage?.azureModelsComponent;
}
private async registerModel(): Promise<boolean> {
try {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
await this.importLocalModel(this.modelsViewData);
} else {
await this.importAzureModel(this.modelsViewData);
}
await this.storeImportConfigTable();
this.showInfoMessage(constants.modelRegisteredSuccessfully);
return true;
} catch (error) {
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
return false;
}
}
/**
* Refresh the pages
*/
public async refresh(): Promise<void> {
await this.wizardView?.refresh();
}
}

View File

@@ -0,0 +1,63 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { CurrentModelsComponent } from './currentModelsComponent';
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
import * as constants from '../../../common/constants';
import { ApiWrapper } from '../../../common/apiWrapper';
import { DialogView } from '../../dialogView';
/**
* Dialog to render registered model views
*/
export class ManageModelsDialog extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string) {
super(apiWrapper, root);
this.dialogView = new DialogView(this._apiWrapper);
}
public dialogView: DialogView;
public currentLanguagesTab: CurrentModelsComponent | undefined;
/**
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this, {
editable: true,
selectable: false
});
let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle);
registerModelButton.onClick(async () => {
await this.sendDataRequest(RegisterModelEventName, this.currentLanguagesTab?.modelTable?.importTable);
});
let dialog = this.dialogView.createDialog(constants.registerModelTitle, [this.currentLanguagesTab]);
dialog.customButtons = [registerModelButton];
this.mainViewPanel = dialog;
dialog.okButton.hidden = true;
dialog.cancelButton.label = constants.extLangDoneButtonText;
dialog.registerCloseValidator(() => {
return false; // Blocks Enter key from closing dialog.
});
this._apiWrapper.openDialog(dialog);
}
/**
* Resets the tabs for given provider Id
*/
public async refresh(): Promise<void> {
if (this.dialogView) {
this.dialogView.refresh();
}
}
}

View File

@@ -0,0 +1,154 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* View to render filters to pick an azure resource
*/
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<ImportedModel> {
private _form: azdata.FormContainer | undefined;
private _nameComponent: azdata.InputBoxComponent | undefined;
private _descriptionComponent: azdata.InputBoxComponent | undefined;
private _createdComponent: azdata.Component | undefined;
private _deployedComponent: azdata.Component | undefined;
private _frameworkComponent: azdata.Component | undefined;
private _frameworkVersionComponent: azdata.Component | undefined;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._createdComponent = modelBuilder.text().withProperties({
value: this._model.created
}).component();
this._deployedComponent = modelBuilder.text().withProperties({
value: this._model.deploymentTime
}).component();
this._frameworkComponent = modelBuilder.text().withProperties({
value: this._model.framework
}).component();
this._frameworkVersionComponent = modelBuilder.text().withProperties({
value: this._model.frameworkVersion
}).component();
this._nameComponent = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength,
value: this._model.modelName
}).component();
this._descriptionComponent = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength,
value: this._model.description,
multiline: true,
height: 50
}).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._nameComponent
},
{
title: '',
component: this._descriptionComponent
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
formBuilder.addFormItems([{
title: constants.modelName,
component: this._nameComponent
}, {
title: constants.modelCreated,
component: this._createdComponent
},
{
title: constants.modelDeployed,
component: this._deployedComponent
}, {
title: constants.modelFramework,
component: this._frameworkComponent
}, {
title: constants.modelFrameworkVersion,
component: this._frameworkVersionComponent
}, {
title: constants.modelDescription,
component: this._descriptionComponent
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._createdComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._frameworkComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._frameworkVersionComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._deployedComponent
});
formBuilder.removeFormItem({
title: constants.modelName,
component: this._nameComponent
});
formBuilder.removeFormItem({
title: constants.modelDescription,
component: this._descriptionComponent
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Returns selected data
*/
public get data(): ImportedModel | undefined {
let model = Object.assign({}, this._model);
model.modelName = this._nameComponent?.value || '';
model.description = this._descriptionComponent?.value || '';
return model;
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,85 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { ModelDetailsComponent } from './modelDetailsComponent';
/**
* View to pick model source
*/
export class ModelDetailsEditPage extends ModelViewBase implements IPageView, IDataComponent<ImportedModel> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelDetailsComponent: ModelDetailsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelDetailsComponent = new ModelDetailsComponent(this._apiWrapper, this, this._model);
this.modelDetailsComponent.registerComponent(modelBuilder);
this.modelDetailsComponent.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ImportedModel | undefined {
return this.modelDetailsComponent?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this.modelDetailsComponent) {
await this.modelDetailsComponent.refresh();
}
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelImportTargetPageTitle;
}
public async disposePage(): Promise<void> {
}
public async validate(): Promise<boolean> {
let validated = false;
if (this.data?.modelName) {
validated = true;
} else {
this.showErrorMessage(constants.modelNameRequiredError);
}
return validated;
}
}

View File

@@ -0,0 +1,97 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { TableSelectionComponent } from '../tableSelectionComponent';
import { DatabaseTable } from '../../../prediction/interfaces';
/**
* View to pick model source
*/
export class ModelImportLocationPage extends ModelViewBase implements IPageView, IDataComponent<DatabaseTable> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public tableSelectionComponent: TableSelectionComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, true);
this.tableSelectionComponent.onSelectedChanged(async () => {
await this.onTableSelected();
});
this.tableSelectionComponent.registerComponent(modelBuilder);
this.tableSelectionComponent.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
private async onTableSelected(): Promise<void> {
if (this.tableSelectionComponent?.data) {
this.importTable = this.tableSelectionComponent?.data;
}
}
/**
* Returns selected data
*/
public get data(): DatabaseTable | undefined {
return this.tableSelectionComponent?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this.tableSelectionComponent) {
await this.tableSelectionComponent.refresh();
}
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelImportTargetPageTitle;
}
public async disposePage(): Promise<void> {
}
public async validate(): Promise<boolean> {
let validated = false;
if (this.data?.databaseName && this.data?.tableName) {
validated = true;
validated = await this.verifyImportConfigTable(this.data);
if (!validated) {
this.showErrorMessage(constants.invalidImportTableSchemaError(this.data?.databaseName, this.data?.tableName));
}
} else {
this.showErrorMessage(constants.invalidImportTableError(this.data?.databaseName, this.data?.tableName));
}
return validated;
}
}

View File

@@ -0,0 +1,207 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType, ModelViewData } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import * as utils from '../../common/utils';
import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
/**
* View to pick model source
*/
export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
private _form: azdata.FormContainer | undefined;
private _title: string = constants.localModelPageTitle;
private _formBuilder: azdata.FormBuilder | undefined;
public localModelsComponent: LocalModelsComponent | undefined;
public azureModelsComponent: AzureModelsComponent | undefined;
public registeredModelsComponent: CurrentModelsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this, this._multiSelect);
this.localModelsComponent.registerComponent(modelBuilder);
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect);
this.azureModelsComponent.registerComponent(modelBuilder);
this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, {
selectable: true,
multiSelect: this._multiSelect,
editable: false
});
this.registeredModelsComponent.registerComponent(modelBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelViewData[] {
return this.modelsViewData;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.modelSourceType === ModelSourceType.Local) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.azureModelsComponent.removeComponents(this._formBuilder);
this.registeredModelsComponent.removeComponents(this._formBuilder);
this.localModelsComponent.addComponents(this._formBuilder);
await this.localModelsComponent.refresh();
}
} else if (this.modelSourceType === ModelSourceType.Azure) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.localModelsComponent.removeComponents(this._formBuilder);
this.azureModelsComponent.addComponents(this._formBuilder);
this.registeredModelsComponent.removeComponents(this._formBuilder);
await this.azureModelsComponent.refresh();
}
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.localModelsComponent.removeComponents(this._formBuilder);
this.azureModelsComponent.removeComponents(this._formBuilder);
this.registeredModelsComponent.addComponents(this._formBuilder);
await this.registeredModelsComponent.refresh();
}
}
}
this.loadTitle();
}
private loadTitle(): void {
if (this.modelSourceType === ModelSourceType.Local) {
this._title = constants.localModelPageTitle;
} else if (this.modelSourceType === ModelSourceType.Azure) {
this._title = constants.azureModelPageTitle;
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
this._title = constants.importedModelsPageTitle;
} else {
this._title = constants.modelSourcePageTitle;
}
}
/**
* Returns page title
*/
public get title(): string {
this.loadTitle();
return this._title;
}
public validate(): Promise<boolean> {
let validated = false;
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0;
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
validated = this.registeredModelsComponent.data !== undefined && this.registeredModelsComponent.data.length > 0;
}
if (!validated) {
this.showErrorMessage(constants.invalidModelToSelectError);
}
return Promise.resolve(validated);
}
public onEnter(): Promise<void> {
return Promise.resolve();
}
public async onLeave(): Promise<void> {
this.modelsViewData = [];
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
if (this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0) {
this.modelsViewData = this.localModelsComponent.data.map(x => {
const fileName = utils.getFileName(x);
return {
modelData: x,
modelDetails: {
modelName: fileName,
fileName: fileName
},
targetImportTable: this.importTable
};
});
}
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
if (this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0) {
this.modelsViewData = this.azureModelsComponent.data.map(x => {
return {
modelData: {
account: x.account,
subscription: x.subscription,
group: x.group,
workspace: x.workspace,
model: x.model
},
modelDetails: {
modelName: x.model?.name || '',
fileName: x.model?.name,
framework: x.model?.framework,
frameworkVersion: x.model?.frameworkVersion,
created: x.model?.createdTime
},
targetImportTable: this.importTable
};
});
}
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
if (this.registeredModelsComponent.data !== undefined) {
this.modelsViewData = this.registeredModelsComponent.data.map(x => {
return {
modelData: x,
modelDetails: {
modelName: ''
},
targetImportTable: this.importTable
};
});
}
}
}
public async disposePage(): Promise<void> {
if (this.azureModelsComponent) {
await this.azureModelsComponent.disposeComponent();
}
if (this.registeredModelsComponent) {
await this.registeredModelsComponent.disposeComponent();
}
}
}

View File

@@ -0,0 +1,83 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelViewData } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { ModelsDetailsTableComponent } from './modelsDetailsTableComponent';
/**
* View to pick model details
*/
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelDetails: ModelsDetailsTableComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelDetails = new ModelsDetailsTableComponent(this._apiWrapper, modelBuilder, this);
this.modelDetails.registerComponent(modelBuilder);
this.modelDetails.addComponents(this._formBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelViewData[] | undefined {
return this.modelDetails?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this.modelDetails) {
await this.modelDetails.refresh();
}
}
public async onEnter(): Promise<void> {
await this.refresh();
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelDetailsPageTitle;
}
public validate(): Promise<boolean> {
if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.modelName)) {
return Promise.resolve(true);
} else {
this.showErrorMessage(constants.modelNameRequiredError);
return Promise.resolve(false);
}
}
}

View File

@@ -0,0 +1,425 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ImportedModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { ManageModelsDialog } from './manageModels/manageModelsDialog';
import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { ImportModelWizard } from './manageModels/importModelWizard';
import * as fs from 'fs';
import * as constants from '../../common/constants';
import { PredictWizard } from './prediction/predictWizard';
import { AzureModelResource } from '../interfaces';
import { PredictService } from '../../prediction/predictService';
import { EditModelDialog } from './manageModels/editModelDialog';
/**
* Model management UI controller
*/
export class ModelManagementController extends ControllerBase {
/**
* Creates new instance
*/
constructor(
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
private _registeredModelService: DeployedModelService,
private _predictService: PredictService) {
super(apiWrapper);
}
/**
* Opens the dialog for model registration
* @param parent parent if the view is opened from another view
* @param controller controller
* @param apiWrapper apiWrapper
* @param root root folder path
*/
public async registerModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
controller = controller || this;
apiWrapper = apiWrapper || this._apiWrapper;
root = root || this._root;
let view = new ImportModelWizard(apiWrapper, root, parent);
if (importTable) {
view.importTable = importTable;
} else {
view.importTable = await controller._registeredModelService.getRecentImportTable();
}
controller.registerEvents(view);
// Open view
//
await view.open();
await view.refresh();
return view;
}
/**
* Opens the dialog to edit model
*/
public async editModel(model: ImportedModel, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
controller = controller || this;
apiWrapper = apiWrapper || this._apiWrapper;
root = root || this._root;
let view = new EditModelDialog(apiWrapper, root, parent, model);
controller.registerEvents(view);
// Open view
//
await view.open();
await view.refresh();
return view;
}
/**
* Opens the wizard for prediction
*/
public async predictModel(): Promise<ModelViewBase> {
let view = new PredictWizard(this._apiWrapper, this._root);
view.importTable = await this._registeredModelService.getRecentImportTable();
this.registerEvents(view);
view.on(LoadModelParametersEventName, async () => {
const modelArtifact = await view.getModelFileName();
await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
modelArtifact?.filePath);
});
// Open view
//
await view.open();
await view.refresh();
return view;
}
/**
* Register events in the main view
* @param view main view
*/
public registerEvents(view: ModelViewBase): void {
// Register events
//
super.registerEvents(view);
view.on(ListAccountsEventName, async () => {
await this.executeAction(view, ListAccountsEventName, this.getAzureAccounts, this._amlService);
});
view.on(ListSubscriptionsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListSubscriptionsEventName, this.getAzureSubscriptions, this._amlService, azureArgs.account);
});
view.on(ListWorkspacesEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListWorkspacesEventName, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group);
});
view.on(ListGroupsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListGroupsEventName, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription);
});
view.on(ListAzureModelsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService
, azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace);
});
view.on(ListModelsEventName, async (args) => {
const table = <DatabaseTable>args;
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService, table);
});
view.on(RegisterLocalModelEventName, async (arg) => {
let models = <ModelViewData[]>arg;
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models);
view.refresh();
});
view.on(RegisterModelEventName, async (args) => {
const importTable = <DatabaseTable>args;
await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root);
});
view.on(EditModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, EditModelEventName, this.editModel, model, view, this, this._apiWrapper, this._root);
});
view.on(UpdateModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, UpdateModelEventName, this.updateModel, this._registeredModelService, model);
});
view.on(DeleteModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, DeleteModelEventName, this.deleteModel, this._registeredModelService, model);
});
view.on(RegisterAzureModelEventName, async (arg) => {
let models = <ModelViewData[]>arg;
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
models);
});
view.on(DownloadAzureModelEventName, async (arg) => {
let registerArgs = <AzureModelResource>arg;
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
});
view.on(ListDatabaseNamesEventName, async () => {
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
});
view.on(ListTableNamesEventName, async (arg) => {
let dbName = <string>arg;
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
});
view.on(ListColumnNamesEventName, async (arg) => {
let tableColumnsArgs = <DatabaseTable>arg;
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
tableColumnsArgs);
});
view.on(PredictModelEventName, async (arg) => {
let predictArgs = <PredictModelEventArgs>arg;
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
predictArgs, predictArgs.model, predictArgs.filePath);
});
view.on(DownloadRegisteredModelEventName, async (arg) => {
let model = <ImportedModel>arg;
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
model);
});
view.on(StoreImportTableEventName, async (arg) => {
let importTable = <DatabaseTable>arg;
await this.executeAction(view, StoreImportTableEventName, this.storeImportTable, this._registeredModelService,
importTable);
});
view.on(VerifyImportTableEventName, async (arg) => {
let importTable = <DatabaseTable>arg;
await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService,
importTable);
});
view.on(SourceModelSelectedEventName, async (arg) => {
view.modelSourceType = <ModelSourceType>arg;
await view.refresh();
});
view.on(SignInToAzureEventName, async () => {
await this.executeAction(view, SignInToAzureEventName, this.signInToAzure, this._amlService);
await view.refresh();
});
}
/**
* Opens the dialog for model management
*/
public async manageRegisteredModels(importTable?: DatabaseTable): Promise<ModelViewBase> {
let view = new ManageModelsDialog(this._apiWrapper, this._root);
if (importTable) {
view.importTable = importTable;
} else {
view.importTable = await this._registeredModelService.getRecentImportTable();
}
// Register events
//
this.registerEvents(view);
// Open view
//
view.open();
return view;
}
private async signInToAzure(service: AzureModelRegistryService): Promise<void> {
return await service.signInToAzure();
}
private async getAzureAccounts(service: AzureModelRegistryService): Promise<azdata.Account[]> {
return await service.getAccounts();
}
private async getAzureSubscriptions(service: AzureModelRegistryService, account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[] | undefined> {
return await service.getSubscriptions(account);
}
private async getAzureGroups(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[] | undefined> {
return await service.getGroups(account, subscription);
}
private async getWorkspaces(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[] | undefined> {
if (!account || !subscription) {
return [];
}
return await service.getWorkspaces(account, subscription, group);
}
private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise<ImportedModel[]> {
return registeredModelService.getDeployedModels(table);
}
private async getAzureModels(
service: AzureModelRegistryService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined): Promise<WorkspaceModel[]> {
if (!account || !subscription || !resourceGroup || !workspace) {
return [];
}
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise<void> {
if (models) {
await Promise.all(models.map(async (model) => {
if (model && model.targetImportTable) {
const localModel = <string>model.modelData;
if (localModel) {
await service.deployLocalModel(localModel, model.modelDetails, model.targetImportTable);
}
} else {
throw Error(constants.invalidModelToRegisterError);
}
}));
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async updateModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
if (model) {
await service.updateModel(model);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async deleteModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
if (model) {
await service.deleteModel(model);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async registerAzureModel(
azureService: AzureModelRegistryService,
service: DeployedModelService,
models: ModelViewData[] | undefined): Promise<void> {
if (!models) {
throw Error(constants.invalidAzureResourceError);
}
await Promise.all(models.map(async (model) => {
if (model && model.targetImportTable) {
const azureModel = <AzureModelResource>model.modelData;
if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) {
let filePath: string | undefined;
try {
const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group,
azureModel.workspace, azureModel.model);
if (filePath) {
await service.deployLocalModel(filePath, model.modelDetails, model.targetImportTable);
} else {
throw Error(constants.invalidModelToRegisterError);
}
} finally {
if (filePath) {
await fs.promises.unlink(filePath);
}
}
}
} else {
throw Error(constants.invalidModelToRegisterError);
}
}));
}
private async getDatabaseList(predictService: PredictService): Promise<string[]> {
return await predictService.getDatabaseList();
}
private async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
return await predictService.getTableList(databaseName);
}
private async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
return await predictService.getTableColumnsList(databaseTable);
}
private async generatePredictScript(
predictService: PredictService,
predictParams: PredictParameters,
registeredModel: ImportedModel | undefined,
filePath: string | undefined
): Promise<string> {
if (!predictParams) {
throw Error(constants.invalidModelToPredictError);
}
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
return result;
}
private async storeImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<void> {
if (table) {
await registeredModelService.storeRecentImportTable(table);
} else {
throw Error(constants.invalidImportTableError(undefined, undefined));
}
}
private async verifyImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<boolean> {
if (table) {
return await registeredModelService.verifyConfigTable(table);
} else {
throw Error(constants.invalidImportTableError(undefined, undefined));
}
}
private async downloadRegisteredModel(
registeredModelService: DeployedModelService,
model: ImportedModel | undefined): Promise<string> {
if (!model) {
throw Error(constants.invalidModelToPredictError);
}
return await registeredModelService.downloadModel(model);
}
private async loadModelParameters(
registeredModelService: DeployedModelService,
model: string | undefined): Promise<ModelParameters | undefined> {
if (!model) {
return undefined;
}
return await registeredModelService.loadModelParameters(model);
}
private async downloadAzureModel(
azureService: AzureModelRegistryService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined): Promise<string> {
if (!account || !subscription || !resourceGroup || !workspace || !model) {
throw Error(constants.invalidAzureResourceError);
}
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
return filePath;
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
}

View File

@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { ModelSourcesComponent } from './modelSourcesComponent';
/**
* View to pick model source
*/
export class ModelSourcePage extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelResources: ModelSourcesComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
this.modelResources.registerComponent(modelBuilder);
this.modelResources.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelSourceType {
return this.modelResources?.data || ModelSourceType.Local;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelSourcePageTitle;
}
public async disposePage(): Promise<void> {
}
}

View File

@@ -0,0 +1,156 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
/**
* View to pick model source
*/
export class ModelSourcesComponent extends ModelViewBase implements IDataComponent<ModelSourceType> {
private _form: azdata.FormContainer | undefined;
private _flexContainer: azdata.FlexContainer | undefined;
private _amlModel: azdata.CardComponent | undefined;
private _localModel: azdata.CardComponent | undefined;
private _registeredModels: azdata.CardComponent | undefined;
private _sourceType: ModelSourceType = ModelSourceType.Local;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._localModel = modelBuilder.card()
.withProperties({
value: 'local',
name: 'modelLocation',
label: constants.localModelSource,
selected: this._options[0] === ModelSourceType.Local,
cardType: azdata.CardType.VerticalButton,
width: 50
}).component();
this._amlModel = modelBuilder.card()
.withProperties({
value: 'aml',
name: 'modelLocation',
label: constants.azureModelSource,
selected: this._options[0] === ModelSourceType.Azure,
cardType: azdata.CardType.VerticalButton,
width: 50
}).component();
this._registeredModels = modelBuilder.card()
.withProperties({
value: 'registered',
name: 'modelLocation',
label: constants.registeredModelsSource,
selected: this._options[0] === ModelSourceType.RegisteredModels,
cardType: azdata.CardType.VerticalButton,
width: 50
}).component();
this._localModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Local;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._amlModel && this._registeredModels) {
this._amlModel.selected = false;
this._registeredModels.selected = false;
}
});
this._amlModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Azure;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._registeredModels) {
this._localModel.selected = false;
this._registeredModels.selected = false;
}
});
this._registeredModels.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.RegisteredModels;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._amlModel) {
this._localModel.selected = false;
this._amlModel.selected = false;
}
});
let components: azdata.Component[] = [];
this._options.forEach(option => {
switch (option) {
case ModelSourceType.Local:
if (this._localModel) {
components.push(this._localModel);
}
break;
case ModelSourceType.Azure:
if (this._amlModel) {
components.push(this._amlModel);
}
break;
case ModelSourceType.RegisteredModels:
if (this._registeredModels) {
components.push(this._registeredModels);
}
break;
}
});
this._sourceType = this._options[0];
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
this._flexContainer = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
}).withItems(components).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._flexContainer
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
formBuilder.addFormItem({ title: '', component: this._flexContainer });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
formBuilder.removeFormItem({ title: '', component: this._flexContainer });
}
}
/**
* Returns selected data
*/
public get data(): ModelSourceType {
return this._sourceType;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,328 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { ViewBase } from '../viewBase';
import { ImportedModel, WorkspaceModel, ImportedModelDetails, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
}
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
details?: ImportedModelDetails
}
export interface PredictModelEventArgs extends PredictParameters {
model?: ImportedModel;
filePath?: string;
}
export enum ModelSourceType {
Local,
Azure,
RegisteredModels
}
export interface ModelViewData {
modelFile?: string;
modelData: AzureModelResource | string | ImportedModel;
modelDetails?: ImportedModelDetails;
targetImportTable?: DatabaseTable;
}
// Event names
//
export const ListModelsEventName = 'listModels';
export const ListAzureModelsEventName = 'listAzureModels';
export const ListAccountsEventName = 'listAccounts';
export const ListDatabaseNamesEventName = 'listDatabaseNames';
export const ListTableNamesEventName = 'listTableNames';
export const ListColumnNamesEventName = 'listColumnNames';
export const ListSubscriptionsEventName = 'listSubscriptions';
export const ListGroupsEventName = 'listGroups';
export const ListWorkspacesEventName = 'listWorkspaces';
export const RegisterLocalModelEventName = 'registerLocalModel';
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
export const PredictModelEventName = 'predictModel';
export const RegisterModelEventName = 'registerModel';
export const EditModelEventName = 'editModel';
export const UpdateModelEventName = 'updateModel';
export const DeleteModelEventName = 'deleteModel';
export const SourceModelSelectedEventName = 'sourceModelSelected';
export const LoadModelParametersEventName = 'loadModelParameters';
export const StoreImportTableEventName = 'storeImportTable';
export const VerifyImportTableEventName = 'verifyImportTable';
export const SignInToAzureEventName = 'signInToAzure';
/**
* Base class for all model management views
*/
export abstract class ModelViewBase extends ViewBase {
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
private _modelsViewData: ModelViewData[] = [];
private _importTable: DatabaseTable | undefined;
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
super(apiWrapper, root, parent);
}
protected getEventNames(): string[] {
return super.getEventNames().concat([ListModelsEventName,
ListAzureModelsEventName,
ListAccountsEventName,
ListSubscriptionsEventName,
ListGroupsEventName,
ListWorkspacesEventName,
RegisterLocalModelEventName,
RegisterAzureModelEventName,
RegisterModelEventName,
SourceModelSelectedEventName,
ListDatabaseNamesEventName,
ListTableNamesEventName,
ListColumnNamesEventName,
PredictModelEventName,
DownloadAzureModelEventName,
DownloadRegisteredModelEventName,
LoadModelParametersEventName,
StoreImportTableEventName,
VerifyImportTableEventName,
EditModelEventName,
UpdateModelEventName,
DeleteModelEventName,
SignInToAzureEventName]);
}
/**
* Parent view
*/
public get parent(): ModelViewBase | undefined {
return this._parent ? <ModelViewBase>this._parent : undefined;
}
/**
* list azure models
*/
public async listAzureModels(workspaceResource: AzureWorkspaceResource): Promise<WorkspaceModel[]> {
const args: AzureResourceEventArgs = workspaceResource;
return await this.sendDataRequest(ListAzureModelsEventName, args);
}
/**
* list registered models
*/
public async listModels(table: DatabaseTable): Promise<ImportedModel[]> {
return await this.sendDataRequest(ListModelsEventName, table);
}
/**
* lists azure accounts
*/
public async listAzureAccounts(): Promise<azdata.Account[]> {
return await this.sendDataRequest(ListAccountsEventName);
}
/**
* lists database names
*/
public async listDatabaseNames(): Promise<string[]> {
return await this.sendDataRequest(ListDatabaseNamesEventName);
}
/**
* lists table names
*/
public async listTableNames(dbName: string): Promise<DatabaseTable[]> {
return await this.sendDataRequest(ListTableNamesEventName, dbName);
}
/**
* lists column names
*/
public async listColumnNames(table: DatabaseTable): Promise<TableColumn[]> {
return await this.sendDataRequest(ListColumnNamesEventName, table);
}
/**
* lists azure subscriptions
* @param account azure account
*/
public async listAzureSubscriptions(account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[]> {
const args: AzureResourceEventArgs = {
account: account
};
return await this.sendDataRequest(ListSubscriptionsEventName, args);
}
/**
* registers local model
* @param localFilePath local file path
*/
public async importLocalModel(models: ModelViewData[]): Promise<void> {
return await this.sendDataRequest(RegisterLocalModelEventName, models);
}
/**
* downloads registered model
* @param model model to download
*/
public async downloadRegisteredModel(model: ImportedModel | undefined): Promise<string> {
return await this.sendDataRequest(DownloadRegisteredModelEventName, model);
}
/**
* download azure model
* @param args azure resource
*/
public async downloadAzureModel(resource: AzureModelResource | undefined): Promise<string> {
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
}
/**
* Loads model parameters
*/
public async loadModelParameters(): Promise<ModelParameters | undefined> {
return await this.sendDataRequest(LoadModelParametersEventName);
}
/**
* registers azure model
* @param args azure resource
*/
public async importAzureModel(models: ModelViewData[]): Promise<void> {
return await this.sendDataRequest(RegisterAzureModelEventName, models);
}
/**
* Stores the name of the table as recent config table for importing models
*/
public async storeImportConfigTable(): Promise<void> {
await this.sendRequest(StoreImportTableEventName, this.importTable);
}
/**
* Verifies if table is valid to import models to
*/
public async verifyImportConfigTable(table: DatabaseTable): Promise<boolean> {
return await this.sendDataRequest(VerifyImportTableEventName, table);
}
/**
* registers azure model
* @param args azure resource
*/
public async generatePredictScript(model: ImportedModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> {
const args: PredictModelEventArgs = Object.assign({}, params, {
model: model,
filePath: filePath,
loadFromRegisteredModel: !filePath
});
return await this.sendDataRequest(PredictModelEventName, args);
}
/**
* list resource groups
* @param account azure account
* @param subscription azure subscription
*/
public async listAzureGroups(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[]> {
const args: AzureResourceEventArgs = {
account: account,
subscription: subscription
};
return await this.sendDataRequest(ListGroupsEventName, args);
}
/**
* Sets model source type
*/
public set modelSourceType(value: ModelSourceType) {
if (this.parent) {
this.parent.modelSourceType = value;
} else {
this._modelSourceType = value;
}
}
/**
* Returns model source type
*/
public get modelSourceType(): ModelSourceType {
if (this.parent) {
return this.parent.modelSourceType;
} else {
return this._modelSourceType;
}
}
/**
* Sets model data
*/
public set modelsViewData(value: ModelViewData[]) {
if (this.parent) {
this.parent.modelsViewData = value;
} else {
this._modelsViewData = value;
}
}
/**
* Returns model data
*/
public get modelsViewData(): ModelViewData[] {
if (this.parent) {
return this.parent.modelsViewData;
} else {
return this._modelsViewData;
}
}
/**
* Sets import table
*/
public set importTable(value: DatabaseTable | undefined) {
if (this.parent) {
this.parent.importTable = value;
} else {
this._importTable = value;
}
}
/**
* Returns import table
*/
public get importTable(): DatabaseTable | undefined {
if (this.parent) {
return this.parent.importTable;
} else {
return this._importTable;
}
}
/**
* lists azure workspaces
* @param account azure account
* @param subscription azure subscription
* @param group azure resource group
*/
public async listWorkspaces(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[]> {
const args: AzureResourceEventArgs = {
account: account,
subscription: subscription,
group: group
};
return await this.sendDataRequest(ListWorkspacesEventName, args);
}
}

View File

@@ -0,0 +1,188 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelViewData } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
/**
* View to pick local models file
*/
export class ModelsDetailsTableComponent extends ModelViewBase implements IDataComponent<ModelViewData[]> {
private _table: azdata.DeclarativeTableComponent | undefined;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Name
displayName: constants.modelFileName,
ariaLabel: constants.modelFileName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelDescription,
ariaLabel: constants.modelDescription,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._table) {
formBuilder.addFormItems([{
title: '',
component: this._table
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._table) {
formBuilder.removeFormItem({
title: '',
component: this._table
});
}
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadData(): Promise<void> {
const models = this.modelsViewData;
if (this._table && models) {
let tableData: any[][] = [];
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
this._table.data = tableData;
}
}
private createTableRow(model: ModelViewData | undefined): any[] {
if (this._modelBuilder && model && model.modelDetails) {
const nameComponent = this._modelBuilder.inputBox().withProperties({
value: model.modelDetails.modelName,
width: this.componentMaxLength - 100,
required: true
}).component();
const descriptionComponent = this._modelBuilder.inputBox().withProperties({
value: model.modelDetails.description,
width: this.componentMaxLength
}).component();
descriptionComponent.onTextChanged(() => {
if (model.modelDetails) {
model.modelDetails.description = descriptionComponent.value;
}
});
nameComponent.onTextChanged(() => {
if (model.modelDetails) {
model.modelDetails.modelName = nameComponent.value || '';
}
});
let deleteButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
width: 15,
height: 15,
iconPath: {
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
}).component();
deleteButton.onDidClick(async () => {
this.modelsViewData = this.modelsViewData.filter(x => x !== model);
await this.refresh();
});
return [model.modelDetails.fileName, nameComponent, descriptionComponent, deleteButton];
}
return [];
}
/**
* Returns selected data
*/
public get data(): ModelViewData[] {
return this.modelsViewData;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._table;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,101 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { InputColumnsComponent } from './inputColumnsComponent';
import { OutputColumnsComponent } from './outputColumnsComponent';
import { PredictParameters } from '../../../prediction/interfaces';
/**
* View to pick model source
*/
export class ColumnsSelectionPage extends ModelViewBase implements IPageView, IDataComponent<PredictParameters> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public inputColumnsComponent: InputColumnsComponent | undefined;
public outputColumnsComponent: OutputColumnsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.inputColumnsComponent = new InputColumnsComponent(this._apiWrapper, this);
this.inputColumnsComponent.registerComponent(modelBuilder);
this.inputColumnsComponent.addComponents(this._formBuilder);
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
this.outputColumnsComponent.registerComponent(modelBuilder);
this.outputColumnsComponent.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): PredictParameters | undefined {
return this.inputColumnsComponent?.data && this.outputColumnsComponent?.data ?
Object.assign({}, this.inputColumnsComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
undefined;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.inputColumnsComponent) {
await this.inputColumnsComponent.refresh();
}
if (this.outputColumnsComponent) {
await this.outputColumnsComponent.refresh();
}
}
}
public async onEnter(): Promise<void> {
await this.inputColumnsComponent?.onLoading();
await this.outputColumnsComponent?.onLoading();
try {
const modelParameters = await this.loadModelParameters();
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
this.inputColumnsComponent.modelParameters = modelParameters;
this.outputColumnsComponent.modelParameters = modelParameters;
await this.inputColumnsComponent.refresh();
await this.outputColumnsComponent.refresh();
}
} catch (error) {
this.showErrorMessage(constants.loadModelParameterFailedError, error);
}
await this.inputColumnsComponent?.onLoaded();
await this.outputColumnsComponent?.onLoaded();
}
/**
* Returns page title
*/
public get title(): string {
return constants.columnSelectionPageTitle;
}
}

View File

@@ -0,0 +1,302 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IDataComponent } from '../../interfaces';
import { PredictColumn, DatabaseTable, TableColumn } from '../../../prediction/interfaces';
import { ModelParameter, ModelParameters } from '../../../modelManagement/interfaces';
/**
* View to render azure models in a table
*/
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
private _table: azdata.DeclarativeTableComponent | undefined;
private _parameters: PredictColumn[] = [];
private _loader: azdata.LoadingComponent;
private _dataTypes: string[] = [
'bigint',
'int',
'smallint',
'real',
'float',
'varchar(MAX)',
'bit'
];
/**
* Creates a view to render azure models in a table
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _forInput: boolean = true) {
super(apiWrapper, parent.root, parent);
this._loader = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.LoadingComponent {
let columnHeader: azdata.DeclarativeTableColumn[];
if (this._forInput) {
columnHeader = [
{ // Action
displayName: constants.columnName,
ariaLabel: constants.columnName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: '',
ariaLabel: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.inputName,
ariaLabel: constants.inputName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 120,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
} else {
columnHeader = [
{ // Name
displayName: constants.outputName,
ariaLabel: constants.outputName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 200,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.displayName,
ariaLabel: constants.displayName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.dataTypeName,
ariaLabel: constants.dataTypeName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
}
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: columnHeader,
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
this._loader = modelBuilder.loadingComponent()
.withItem(this._table)
.withProperties({
loading: true
}).component();
return this._loader;
}
public async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
public async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
public get component(): azdata.Component {
return this._loader;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadInputs(modelParameters: ModelParameters | undefined, table: DatabaseTable): Promise<void> {
await this.onLoading();
this._parameters = [];
let tableData: any[][] = [];
if (this._table) {
if (this._forInput) {
const columns = await this.listColumnNames(table);
if (modelParameters?.inputs && columns) {
tableData = tableData.concat(modelParameters.inputs.map(input => this.createInputTableRow(input, columns)));
}
}
this._table.data = tableData;
}
await this.onLoaded();
}
public async loadOutputs(modelParameters: ModelParameters | undefined): Promise<void> {
this.onLoading();
this._parameters = [];
let tableData: any[][] = [];
if (this._table) {
if (!this._forInput) {
if (modelParameters?.outputs && this._dataTypes) {
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
}
}
this._table.data = tableData;
}
this.onLoaded();
}
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
if (this._modelBuilder) {
let nameInput = this._modelBuilder.dropDown().withProperties({
values: dataTypes,
width: this.componentMaxLength
}).component();
const name = modelParameter.name;
const dataType = dataTypes.find(x => x === modelParameter.type);
if (dataType) {
nameInput.value = dataType;
}
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
nameInput.onValueChanged(() => {
const value = <string>nameInput.value;
if (value !== modelParameter.type) {
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.dataType = value;
}
}
});
let displayNameInput = this._modelBuilder.inputBox().withProperties({
value: name,
width: 200
}).component();
displayNameInput.onTextChanged(() => {
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.columnName = displayNameInput.value || name;
}
});
return [`${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
}
return [];
}
private createInputTableRow(modelParameter: ModelParameter, columns: TableColumn[] | undefined): any[] {
if (this._modelBuilder && columns) {
const values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; });
let nameInput = this._modelBuilder.dropDown().withProperties({
values: values,
width: this.componentMaxLength
}).component();
const name = modelParameter.name;
let column = values.find(x => x.name === modelParameter.name);
if (!column) {
column = values[0];
}
nameInput.value = column;
this._parameters.push({ columnName: column.name, paramName: name });
nameInput.onValueChanged(() => {
const selectedColumn = nameInput.value;
const value = selectedColumn ? (<azdata.CategoryValue>selectedColumn).name : undefined;
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.columnName = value || '';
}
});
const label = this._modelBuilder.inputBox().withProperties({
value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`,
enabled: false,
width: this.componentMaxLength
}).component();
const image = this._modelBuilder.image().withProperties({
width: 50,
height: 50,
iconPath: {
dark: this.asAbsolutePath('images/arrow.svg'),
light: this.asAbsolutePath('images/arrow.svg')
},
iconWidth: 20,
iconHeight: 20,
title: 'maps'
}).component();
return [nameInput, image, label];
}
return [];
}
/**
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
return this._parameters;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,142 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
import { ModelParameters } from '../../../modelManagement/interfaces';
import { ColumnsTable } from './columnsTable';
import { TableSelectionComponent } from '../tableSelectionComponent';
/**
* View to render filters to pick an azure resource
*/
export class InputColumnsComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
private _form: azdata.FormContainer | undefined;
private _tableSelectionComponent: TableSelectionComponent | undefined;
private _columns: ColumnsTable | undefined;
private _modelParameters: ModelParameters | undefined;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
this._tableSelectionComponent.registerComponent(modelBuilder);
this._tableSelectionComponent.onSelectedChanged(async () => {
await this.onTableSelected();
});
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.inputColumns,
component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
formBuilder.addFormItems([{
title: '',
component: this._tableSelectionComponent.component
}, {
title: constants.inputColumns,
component: this._columns.component
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
formBuilder.removeFormItem({
title: '',
component: this._tableSelectionComponent.component
});
formBuilder.removeFormItem({
title: constants.inputColumns,
component: this._columns.component
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Returns selected data
*/
public get data(): PredictInputParameters | undefined {
return Object.assign({}, this.databaseTable, {
inputColumns: this.columnNames
});
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
if (this._tableSelectionComponent) {
this._tableSelectionComponent.refresh();
}
}
public set modelParameters(value: ModelParameters) {
this._modelParameters = value;
}
public async onLoading(): Promise<void> {
if (this._columns) {
await this._columns.onLoading();
}
}
public async onLoaded(): Promise<void> {
if (this._columns) {
await this._columns.onLoaded();
}
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
private async onTableSelected(): Promise<void> {
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
}
private get databaseTable(): DatabaseTable {
let selectedItem = this._tableSelectionComponent?.data;
return {
databaseName: selectedItem?.databaseName,
tableName: selectedItem?.tableName,
schema: selectedItem?.schema
};
}
private get columnNames(): PredictColumn[] | undefined {
return this._columns?.data;
}
}

View File

@@ -0,0 +1,35 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as utils from '../../../common/utils';
/**
* Wizard to register a model
*/
export class ModelArtifact {
/**
* Creates new model artifact
*/
constructor(private _filePath: string, private _deleteAtClose: boolean = true) {
}
public get filePath(): string {
return this._filePath;
}
/**
* Closes the artifact and disposes the resources
*/
public async close(): Promise<void> {
if (this._deleteAtClose) {
try {
await utils.deleteFile(this._filePath);
} catch {
}
}
}
}

View File

@@ -0,0 +1,109 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { PredictColumn } from '../../../prediction/interfaces';
import { ColumnsTable } from './columnsTable';
import { ModelParameters } from '../../../modelManagement/interfaces';
/**
* View to render filters to pick an azure resource
*/
export class OutputColumnsComponent extends ModelViewBase implements IDataComponent<PredictColumn[]> {
private _form: azdata.FormContainer | undefined;
private _columns: ColumnsTable | undefined;
private _modelParameters: ModelParameters | undefined;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this, false);
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._columns) {
formBuilder.addFormItems([{
title: constants.outputColumns,
component: this._columns.component
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._columns) {
formBuilder.removeFormItem({
title: constants.outputColumns,
component: this._columns.component
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
if (this._modelParameters) {
this._columns?.loadOutputs(this._modelParameters);
}
}
public set modelParameters(value: ModelParameters) {
this._modelParameters = value;
}
public async onLoading(): Promise<void> {
if (this._columns) {
await this._columns.onLoading();
}
}
public async onLoaded(): Promise<void> {
if (this._columns) {
await this._columns.onLoaded();
}
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
/**
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
return this._columns?.data;
}
}

View File

@@ -0,0 +1,159 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
import { AzureModelsComponent } from '../azureModelsComponent';
import * as constants from '../../../common/constants';
import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage';
import { ColumnsSelectionPage } from './columnsSelectionPage';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { ModelArtifact } from './modelArtifact';
import { ModelBrowsePage } from '../modelBrowsePage';
/**
* Wizard to register a model
*/
export class PredictWizard extends ModelViewBase {
public modelSourcePage: ModelSourcePage | undefined;
public columnsSelectionPage: ColumnsSelectionPage | undefined;
public modelBrowsePage: ModelBrowsePage | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
constructor(
apiWrapper: ApiWrapper,
root: string,
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
}
/**
* Opens a dialog to manage packages used by notebooks.
*/
public async open(): Promise<void> {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this, false);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
[this.modelSourcePage,
this.modelBrowsePage,
this.columnsSelectionPage]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.predictModel;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.doneButton.onClick(async () => {
await this.onClose();
});
wizard.cancelButton.onClick(async () => {
await this.onClose();
});
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
let validated: boolean = true;
if (pageInfo.newPage > pageInfo.lastPage) {
validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
}
if (validated) {
if (pageInfo.newPage === undefined) {
this.onLoading();
await this.predict();
this.onLoaded();
if (this._parentView) {
this._parentView?.refresh();
}
}
return true;
}
return validated;
});
await wizard.open();
}
private onLoading(): void {
this.refreshButtons(true);
}
private onLoaded(): void {
this.refreshButtons(false);
}
private refreshButtons(loading: boolean): void {
if (this.wizardView && this.wizardView.wizard) {
this.wizardView.wizard.cancelButton.enabled = !loading;
this.wizardView.wizard.cancelButton.enabled = !loading;
}
}
public get modelResources(): ModelSourcesComponent | undefined {
return this.modelSourcePage?.modelResources;
}
public get localModelsComponent(): LocalModelsComponent | undefined {
return this.modelBrowsePage?.localModelsComponent;
}
public get azureModelsComponent(): AzureModelsComponent | undefined {
return this.modelBrowsePage?.azureModelsComponent;
}
public async getModelFileName(): Promise<ModelArtifact | undefined> {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
return new ModelArtifact(this.localModelsComponent.data[0], false);
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
return await this.azureModelsComponent.getDownloadedModel();
} else if (this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) {
return await this.modelBrowsePage.registeredModelsComponent.modelTable?.getDownloadedModel();
}
return undefined;
}
private async predict(): Promise<boolean> {
try {
let modelFilePath: string | undefined;
let registeredModel: ImportedModel | undefined = undefined;
if (this.modelResources && this.modelResources.data && this.modelResources.data === ModelSourceType.RegisteredModels
&& this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) {
const data = this.modelBrowsePage?.registeredModelsComponent?.data;
registeredModel = data && data.length > 0 ? data[0] : undefined;
} else {
const artifact = await this.getModelFileName();
modelFilePath = artifact?.filePath;
}
await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data);
return true;
} catch (error) {
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
return false;
}
}
private async onClose(): Promise<void> {
const artifact = await this.getModelFileName();
if (artifact) {
artifact.close();
}
await this.wizardView?.disposePages();
}
/**
* Refresh the pages
*/
public async refresh(): Promise<void> {
await this.wizardView?.refresh();
}
}

View File

@@ -0,0 +1,213 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
import { DatabaseTable } from '../../prediction/interfaces';
/**
* View to render filters to pick an azure resource
*/
export class TableSelectionComponent extends ModelViewBase implements IDataComponent<DatabaseTable> {
private _form: azdata.FormContainer | undefined;
private _databases: azdata.DropDownComponent | undefined;
private _selectedTableName: string = '';
private _tables: azdata.DropDownComponent | undefined;
private _dbNames: string[] = [];
private _tableNames: DatabaseTable[] = [];
private _dbTableComponent: azdata.FlexContainer | undefined;
private tableMaxLength = this.componentMaxLength * 2 + 70;
private _onSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onSelectedChanged: vscode.Event<void> = this._onSelectedChanged.event;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _editable: boolean) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._databases = modelBuilder.dropDown().withProperties({
width: this.componentMaxLength,
editable: this._editable,
fireOnTextChange: this._editable
}).component();
this._tables = modelBuilder.dropDown().withProperties({
width: this.componentMaxLength,
editable: this._editable,
fireOnTextChange: this._editable
}).component();
this._databases.onValueChanged(async () => {
await this.onDatabaseSelected();
});
this._tables.onValueChanged(async (value) => {
// There's an issue with dropdown doesn't set the value in editable mode. this is the workaround
if (this._tables && value) {
this._selectedTableName = this._editable ? value : value.selected;
}
await this.onTableSelected();
});
const databaseForm = modelBuilder.formContainer().withFormItems([{
title: constants.columnDatabase,
component: this._databases,
}]).withLayout({
padding: '0px'
}).component();
const tableForm = modelBuilder.formContainer().withFormItems([{
title: constants.columnTable,
component: this._tables
}]).withLayout({
padding: '0px'
}).component();
this._dbTableComponent = modelBuilder.flexContainer().withItems([
databaseForm,
tableForm
], {
flex: '0 0 auto',
CSSStyles: {
'align-items': 'flex-start'
}
}).withLayout({
flexFlow: 'row',
justifyContent: 'space-between',
width: this.tableMaxLength
}).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._dbTableComponent
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._databases && this._tables) {
formBuilder.addFormItems([{
title: constants.databaseName,
component: this._databases
}, {
title: constants.tableName,
component: this._tables
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._databases && this._tables) {
formBuilder.removeFormItem({
title: constants.databaseName,
component: this._databases
});
formBuilder.removeFormItem({
title: constants.tableName,
component: this._tables
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._dbTableComponent;
}
/**
* Returns selected data
*/
public get data(): DatabaseTable | undefined {
return this.databaseTable;
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
this._dbNames = await this.listDatabaseNames();
if (this._databases && this._dbNames && this._dbNames.length > 0) {
this._databases.values = this._dbNames;
if (this.importTable) {
this._databases.value = this.importTable.databaseName;
} else {
this._databases.value = this._dbNames[0];
}
}
await this.onDatabaseSelected();
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
private async onDatabaseSelected(): Promise<void> {
this._tableNames = await this.listTableNames(this.databaseName || '');
if (this._tables && this._tableNames && this._tableNames.length > 0) {
this._tables.values = this._tableNames.map(t => this.getTableFullName(t));
if (this.importTable) {
const selectedTable = this._tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema);
if (selectedTable) {
this._selectedTableName = this.getTableFullName(selectedTable);
this._tables.value = this.getTableFullName(selectedTable);
} else {
this._selectedTableName = this._editable ? this.getTableFullName(this.importTable) : this.getTableFullName(this._tableNames[0]);
}
} else {
this._selectedTableName = this.getTableFullName(this._tableNames[0]);
}
this._tables.value = this._selectedTableName;
} else if (this._tables) {
this._tables.values = [];
this._tables.value = '';
}
await this.onTableSelected();
}
private getTableFullName(table: DatabaseTable): string {
return `${table.schema}.${table.tableName}`;
}
private async onTableSelected(): Promise<void> {
this._onSelectedChanged.fire();
}
private get databaseName(): string | undefined {
return <string>this._databases?.value;
}
private get databaseTable(): DatabaseTable {
let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._selectedTableName);
if (!selectedItem) {
const value = this._selectedTableName;
const parts = value ? value.split('.') : undefined;
selectedItem = {
databaseName: this.databaseName,
tableName: parts && parts.length > 1 ? parts[1] : value,
schema: parts && parts.length > 1 ? parts[0] : 'dbo',
};
}
return {
databaseName: this.databaseName,
tableName: selectedItem?.tableName,
schema: selectedItem?.schema
};
}
}