Machine Learning Model Registry - Iteration1 (#9105)

* Machine learning services extension - model registration wizard
This commit is contained in:
Leila Lali
2020-02-26 09:19:48 -08:00
committed by GitHub
parent 067fcc8dfb
commit ff207859d6
46 changed files with 3990 additions and 210 deletions

View File

@@ -0,0 +1,103 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
import { AzureModelsTable } from './azureModelsTable';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent, AzureModelResource } from '../interfaces';
export class AzureModelsComponent extends ModelViewBase implements IPageView, IDataComponent<AzureModelResource> {
public azureModelsTable: AzureModelsTable | undefined;
public azureFilterComponent: AzureResourceFilterComponent | undefined;
private _loader: azdata.LoadingComponent | undefined;
private _form: azdata.FormContainer | undefined;
/**
* Component to render a view to pick an azure model
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this);
this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this);
this._loader = modelBuilder.loadingComponent()
.withItem(this.azureModelsTable.component)
.withProperties({
loading: true
}).component();
this.azureFilterComponent.onWorkspacesSelected(async () => {
await this.onLoading();
await this.azureModelsTable?.loadData(this.azureFilterComponent?.data);
await this.onLoaded();
});
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureModelFilter,
component: this.azureFilterComponent.component
}, {
title: constants.azureModels,
component: this._loader
}]).component();
return this._form;
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
private async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Loads the data in the components
*/
public async loadData(): Promise<void> {
await this.azureFilterComponent?.loadData();
}
/**
* Returns selected data
*/
public get data(): AzureModelResource | undefined {
return Object.assign({}, this.azureFilterComponent?.data, {
model: this.azureModelsTable?.data
});
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
/**
* Returns the title of the page
*/
public get title(): string {
return constants.azureModelsTitle;
}
}

View File

@@ -0,0 +1,141 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { WorkspaceModel } from '../../modelManagement/interfaces';
import { IDataComponent, AzureWorkspaceResource } from '../interfaces';
/**
* View to render azure models in a table
*/
export class AzureModelsTable extends ModelViewBase implements IDataComponent<WorkspaceModel> {
private _table: azdata.DeclarativeTableComponent;
private _selectedModelId: any;
private _models: WorkspaceModel[] | undefined;
/**
* Creates a view to render azure models in a table
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._table = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Id
displayName: constants.modeIld,
ariaLabel: constants.modeIld,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
}
public get component(): azdata.DeclarativeTableComponent {
return this._table;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadData(workspaceResource?: AzureWorkspaceResource | undefined): Promise<void> {
if (this._table && workspaceResource) {
this._models = await this.listAzureModels(workspaceResource);
let tableData: any[][] = [];
if (this._models) {
tableData = tableData.concat(this._models.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
}
private createTableRow(model: WorkspaceModel): any[] {
if (this._modelBuilder) {
let selectModelButton = this._modelBuilder.radioButton().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
selectModelButton.onDidClick(() => {
this._selectedModelId = model.id;
});
return [model.id, model.name, selectModelButton];
}
return [];
}
/**
* Returns selected data
*/
public get data(): WorkspaceModel | undefined {
if (this._models && this._selectedModelId) {
return this._models.find(x => x.id === this._selectedModelId);
}
return undefined;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,168 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { azureResource } from '../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import * as constants from '../../common/constants';
import { AzureWorkspaceResource, IDataComponent } from '../interfaces';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 200;
export class AzureResourceFilterComponent extends ModelViewBase implements IDataComponent<AzureWorkspaceResource> {
private _form: azdata.FormContainer;
private _accounts: azdata.DropDownComponent;
private _subscriptions: azdata.DropDownComponent;
private _groups: azdata.DropDownComponent;
private _workspaces: azdata.DropDownComponent;
private _azureAccounts: azdata.Account[] = [];
private _azureSubscriptions: azureResource.AzureResourceSubscription[] = [];
private _azureGroups: azureResource.AzureResource[] = [];
private _azureWorkspaces: Workspace[] = [];
private _onWorkspacesSelected: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onWorkspacesSelected: vscode.Event<void> = this._onWorkspacesSelected.event;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._accounts = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._subscriptions = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._groups = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._workspaces = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._accounts.onValueChanged(async () => {
await this.onAccountSelected();
});
this._subscriptions.onValueChanged(async () => {
await this.onSubscriptionSelected();
});
this._groups.onValueChanged(async () => {
await this.onGroupSelected();
});
this._workspaces.onValueChanged(async () => {
await this.onWorkspaceSelected();
});
this._form = this._modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._accounts
}, {
title: constants.azureSubscription,
component: this._subscriptions
}, {
title: constants.azureGroup,
component: this._groups
}, {
title: constants.azureModelWorkspace,
component: this._workspaces
}]).component();
}
/**
* Returns the created component
*/
public get component(): azdata.Component {
return this._form;
}
/**
* Returns selected data
*/
public get data(): AzureWorkspaceResource | undefined {
return {
account: this.account,
subscription: this.subscription,
group: this.group,
workspace: this.workspace
};
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
this._azureAccounts = await this.listAzureAccounts();
if (this._azureAccounts && this._azureAccounts.length > 0) {
let values = this._azureAccounts.map(a => { return { displayName: a.displayInfo.displayName, name: a.key.accountId }; });
this._accounts.values = values;
this._accounts.value = values[0];
}
await this.onAccountSelected();
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
private async onAccountSelected(): Promise<void> {
this._azureSubscriptions = await this.listAzureSubscriptions(this.account);
if (this._azureSubscriptions && this._azureSubscriptions.length > 0) {
let values = this._azureSubscriptions.map(s => { return { displayName: s.name, name: s.id }; });
this._subscriptions.values = values;
this._subscriptions.value = values[0];
}
await this.onSubscriptionSelected();
}
private async onSubscriptionSelected(): Promise<void> {
this._azureGroups = await this.listAzureGroups(this.account, this.subscription);
if (this._azureGroups && this._azureGroups.length > 0) {
let values = this._azureGroups.map(s => { return { displayName: s.name, name: s.id }; });
this._groups.values = values;
this._groups.value = values[0];
}
await this.onGroupSelected();
}
private async onGroupSelected(): Promise<void> {
this._azureWorkspaces = await this.listWorkspaces(this.account, this.subscription, this.group);
if (this._azureWorkspaces && this._azureWorkspaces.length > 0) {
let values = this._azureWorkspaces.map(s => { return { displayName: s.name || '', name: s.id || '' }; });
this._workspaces.values = values;
this._workspaces.value = values[0];
}
this.onWorkspaceSelected();
}
private onWorkspaceSelected(): void {
this._onWorkspacesSelected.fire();
}
private get workspace(): Workspace | undefined {
return this._azureWorkspaces ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
}
private get account(): azdata.Account | undefined {
return this._azureAccounts ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
}
private get group(): azureResource.AzureResource | undefined {
return this._azureGroups ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
}
private get subscription(): azureResource.AzureResourceSubscription | undefined {
return this._azureSubscriptions ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
}
}

View File

@@ -0,0 +1,104 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import { ModelViewBase, RegisterModelEventName } from './modelViewBase';
import { CurrentModelsTable } from './currentModelsTable';
import { ApiWrapper } from '../../common/apiWrapper';
import { IPageView } from '../interfaces';
/**
* View to render current registered models
*/
export class CurrentModelsPage extends ModelViewBase implements IPageView {
private _tableComponent: azdata.DeclarativeTableComponent | undefined;
private _dataTable: CurrentModelsTable | undefined;
private _loader: azdata.LoadingComponent | undefined;
/**
*
* @param apiWrapper Creates new view
* @param parent page parent
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._dataTable = new CurrentModelsTable(this._apiWrapper, modelBuilder, this);
this._tableComponent = this._dataTable.component;
let registerButton = modelBuilder.button().withProperties({
label: constants.registerModelButton,
width: this.buttonMaxLength
}).component();
registerButton.onDidClick(async () => {
await this.sendDataRequest(RegisterModelEventName);
});
let formModel = modelBuilder.formContainer()
.withFormItems([{
title: '',
component: registerButton
}, {
component: this._tableComponent,
title: ''
}]).component();
this._loader = modelBuilder.loadingComponent()
.withItem(formModel)
.withProperties({
loading: true
}).component();
return this._loader;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._loader;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.onLoading();
try {
await this._dataTable?.refresh();
} catch (err) {
this.showErrorMessage(constants.getErrorMessage(err));
} finally {
await this.onLoaded();
}
}
/**
* returns the title of the page
*/
public get title(): string {
return constants.currentModelsTitle;
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
private async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
}

View File

@@ -0,0 +1,131 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { RegisteredModel } from '../../modelManagement/interfaces';
/**
* View to render registered models table
*/
export class CurrentModelsTable extends ModelViewBase {
private _table: azdata.DeclarativeTableComponent;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._table = this.registerComponent(this._modelBuilder);
}
/**
*
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Id
displayName: constants.modeIld,
ariaLabel: constants.modeIld,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
}
/**
* Returns the component
*/
public get component(): azdata.DeclarativeTableComponent {
return this._table;
}
/**
* Loads the data in the component
*/
public async loadData(): Promise<void> {
let models: RegisteredModel[] | undefined;
models = await this.listModels();
let tableData: any[][] = [];
if (models) {
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
private createTableRow(model: RegisteredModel): any[] {
if (this._modelBuilder) {
let editLanguageButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
width: 15,
height: 15
}).component();
editLanguageButton.onDidClick(() => {
});
return [model.id, model.name, editLanguageButton];
}
return [];
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,92 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
/**
* View to pick local models file
*/
export class LocalModelsComponent extends ModelViewBase implements IPageView, IDataComponent<string> {
private _form: azdata.FormContainer | undefined;
private _localPath: azdata.InputBoxComponent | undefined;
private _localBrowse: azdata.ButtonComponent | undefined;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._localPath = modelBuilder.inputBox().withProperties({
value: '',
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength
}).component();
this._localBrowse = modelBuilder.button().withProperties({
label: constants.browseModels,
width: this.browseButtonMaxLength,
CSSStyles: {
'text-align': 'end'
}
}).component();
this._localBrowse.onDidClick(async () => {
const filePath = await this.getLocalFilePath();
if (this._localPath) {
this._localPath.value = filePath;
}
});
let flexFilePathModel = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
}).withItems([
this._localPath, this._localBrowse]
).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: flexFilePathModel
}]).component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): string {
return this._localPath?.value || '';
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
/**
* Returns the page title
*/
public get title(): string {
return constants.localModelsTitle;
}
}

View File

@@ -0,0 +1,188 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
import { RegisteredModelsDialog } from './registeredModelsDialog';
import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName } from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { RegisterModelWizard } from './registerModelWizard';
import * as fs from 'fs';
import * as constants from '../../common/constants';
/**
* Model management UI controller
*/
export class ModelManagementController extends ControllerBase {
/**
* Creates new instance
*/
constructor(
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
private _registeredModelService: RegisteredModelService) {
super(apiWrapper);
}
/**
* Opens the dialog for model registration
* @param parent parent if the view is opened from another view
* @param controller controller
* @param apiWrapper apiWrapper
* @param root root folder path
*/
public async registerModel(parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
controller = controller || this;
apiWrapper = apiWrapper || this._apiWrapper;
root = root || this._root;
let view = new RegisterModelWizard(apiWrapper, root, parent);
controller.registerEvents(view);
// Open view
//
view.open();
return view;
}
/**
* Register events in the main view
* @param view main view
*/
public registerEvents(view: ModelViewBase): void {
// Register events
//
super.registerEvents(view);
view.on(ListAccountsEventName, async () => {
await this.executeAction(view, ListAccountsEventName, this.getAzureAccounts, this._amlService);
});
view.on(ListSubscriptionsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListSubscriptionsEventName, this.getAzureSubscriptions, this._amlService, azureArgs.account);
});
view.on(ListWorkspacesEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListWorkspacesEventName, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group);
});
view.on(ListGroupsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListGroupsEventName, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription);
});
view.on(ListAzureModelsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService
, azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace);
});
view.on(ListModelsEventName, async () => {
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService);
});
view.on(RegisterLocalModelEventName, async (arg) => {
let registerArgs = <RegisterLocalModelEventArgs>arg;
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, registerArgs.filePath);
view.refresh();
});
view.on(RegisterModelEventName, async () => {
await this.executeAction(view, RegisterModelEventName, this.registerModel, view, this, this._apiWrapper, this._root);
});
view.on(RegisterAzureModelEventName, async (arg) => {
let registerArgs = <RegisterAzureModelEventArgs>arg;
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
});
view.on(SourceModelSelectedEventName, () => {
view.refresh();
});
}
/**
* Opens the dialog for model management
*/
public async manageRegisteredModels(): Promise<ModelViewBase> {
let view = new RegisteredModelsDialog(this._apiWrapper, this._root);
// Register events
//
this.registerEvents(view);
// Open view
//
view.open();
return view;
}
private async getAzureAccounts(service: AzureModelRegistryService): Promise<azdata.Account[]> {
return await service.getAccounts();
}
private async getAzureSubscriptions(service: AzureModelRegistryService, account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[] | undefined> {
return await service.getSubscriptions(account);
}
private async getAzureGroups(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[] | undefined> {
return await service.getGroups(account, subscription);
}
private async getWorkspaces(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[] | undefined> {
if (!account || !subscription) {
return [];
}
return await service.getWorkspaces(account, subscription, group);
}
private async getRegisteredModels(registeredModelService: RegisteredModelService): Promise<RegisteredModel[]> {
return registeredModelService.getRegisteredModels();
}
private async getAzureModels(
service: AzureModelRegistryService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined): Promise<WorkspaceModel[]> {
if (!account || !subscription || !resourceGroup || !workspace) {
return [];
}
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: RegisteredModelService, filePath?: string): Promise<void> {
if (filePath) {
await service.registerLocalModel(filePath);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async registerAzureModel(
azureService: AzureModelRegistryService,
service: RegisteredModelService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined): Promise<void> {
if (!account || !subscription || !resourceGroup || !workspace || !model) {
throw Error(constants.invalidAzureResourceError);
}
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
await service.registerLocalModel(filePath);
await fs.promises.unlink(filePath);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
}

View File

@@ -0,0 +1,102 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SourceModelSelectedEventName } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
export enum ModelSourceType {
Local,
Azure
}
/**
* View tp pick model source
*/
export class ModelSourcesComponent extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
private _form: azdata.FormContainer | undefined;
private _amlModel: azdata.RadioButtonComponent | undefined;
private _localModel: azdata.RadioButtonComponent | undefined;
private _isLocalModel: boolean = true;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._localModel = modelBuilder.radioButton()
.withProperties({
value: 'local',
name: 'modelLocation',
label: constants.localModelSource,
checked: true
}).component();
this._amlModel = modelBuilder.radioButton()
.withProperties({
value: 'aml',
name: 'modelLocation',
label: constants.azureModelSource,
}).component();
this._localModel.onDidClick(() => {
this._isLocalModel = true;
this.sendRequest(SourceModelSelectedEventName);
});
this._amlModel.onDidClick(() => {
this._isLocalModel = false;
this.sendRequest(SourceModelSelectedEventName);
});
let flex = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'column',
justifyContent: 'space-between'
}).withItems([
this._localModel, this._amlModel]
).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.modelSourcesTitle,
component: flex
}]).component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelSourceType {
return this._isLocalModel ? ModelSourceType.Local : ModelSourceType.Azure;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelSourcesTitle;
}
}

View File

@@ -0,0 +1,147 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { ViewBase } from '../viewBase';
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
}
export interface RegisterAzureModelEventArgs extends AzureModelResource {
model?: WorkspaceModel;
}
export interface RegisterLocalModelEventArgs extends AzureResourceEventArgs {
filePath?: string;
}
// Event names
//
export const ListModelsEventName = 'listModels';
export const ListAzureModelsEventName = 'listAzureModels';
export const ListAccountsEventName = 'listAccounts';
export const ListSubscriptionsEventName = 'listSubscriptions';
export const ListGroupsEventName = 'listGroups';
export const ListWorkspacesEventName = 'listWorkspaces';
export const RegisterLocalModelEventName = 'registerLocalModel';
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const RegisterModelEventName = 'registerModel';
export const SourceModelSelectedEventName = 'sourceModelSelected';
/**
* Base class for all model management views
*/
export abstract class ModelViewBase extends ViewBase {
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
super(apiWrapper, root, parent);
}
protected getEventNames(): string[] {
return super.getEventNames().concat([ListModelsEventName,
ListAzureModelsEventName,
ListAccountsEventName,
ListSubscriptionsEventName,
ListGroupsEventName,
ListWorkspacesEventName,
RegisterLocalModelEventName,
RegisterAzureModelEventName,
RegisterModelEventName,
SourceModelSelectedEventName]);
}
/**
* Parent view
*/
public get parent(): ModelViewBase | undefined {
return this._parent ? <ModelViewBase>this._parent : undefined;
}
/**
* list azure models
*/
public async listAzureModels(workspaceResource: AzureWorkspaceResource): Promise<WorkspaceModel[]> {
const args: AzureResourceEventArgs = workspaceResource;
return await this.sendDataRequest(ListAzureModelsEventName, args);
}
/**
* list registered models
*/
public async listModels(): Promise<RegisteredModel[]> {
return await this.sendDataRequest(ListModelsEventName);
}
/**
* lists azure accounts
*/
public async listAzureAccounts(): Promise<azdata.Account[]> {
return await this.sendDataRequest(ListAccountsEventName);
}
/**
* lists azure subscriptions
* @param account azure account
*/
public async listAzureSubscriptions(account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[]> {
const args: AzureResourceEventArgs = {
account: account
};
return await this.sendDataRequest(ListSubscriptionsEventName, args);
}
/**
* registers local model
* @param localFilePath local file path
*/
public async registerLocalModel(localFilePath: string | undefined): Promise<void> {
const args: RegisterLocalModelEventArgs = {
filePath: localFilePath
};
return await this.sendDataRequest(RegisterLocalModelEventName, args);
}
/**
* registers azure model
* @param args azure resource
*/
public async registerAzureModel(args: RegisterAzureModelEventArgs | undefined): Promise<void> {
return await this.sendDataRequest(RegisterAzureModelEventName, args);
}
/**
* list resource groups
* @param account azure account
* @param subscription azure subscription
*/
public async listAzureGroups(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[]> {
const args: AzureResourceEventArgs = {
account: account,
subscription: subscription
};
return await this.sendDataRequest(ListGroupsEventName, args);
}
/**
* lists azure workspaces
* @param account azure account
* @param subscription azure subscription
* @param group azure resource group
*/
public async listWorkspaces(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[]> {
const args: AzureResourceEventArgs = {
account: account,
subscription: subscription,
group: group
};
return await this.sendDataRequest(ListWorkspacesEventName, args);
}
}

View File

@@ -0,0 +1,96 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import * as constants from '../../common/constants';
import { WizardView } from '../wizardView';
/**
* Wizard to register a model
*/
export class RegisterModelWizard extends ModelViewBase {
public modelResources: ModelSourcesComponent | undefined;
public localModelsComponent: LocalModelsComponent | undefined;
public azureModelsComponent: AzureModelsComponent | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
constructor(
apiWrapper: ApiWrapper,
root: string,
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
}
/**
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.registerModelWizardTitle, [this.modelResources, this.localModelsComponent]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.azureRegisterModel;
wizard.generateScriptButton.hidden = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
if (pageInfo.newPage === undefined) {
await this.registerModel();
if (this._parentView) {
this._parentView?.refresh();
}
return true;
}
return true;
});
wizard.open();
}
private async registerModel(): Promise<boolean> {
try {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
await this.registerLocalModel(this.localModelsComponent.data);
} else {
await this.registerAzureModel(this.azureModelsComponent?.data);
}
this.showInfoMessage(constants.modelRegisteredSuccessfully);
return true;
} catch (error) {
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
return false;
}
}
private loadPages(): void {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
this.wizardView?.addWizardPage(this.localModelsComponent, 1);
} else if (this.azureModelsComponent) {
this.wizardView?.addWizardPage(this.azureModelsComponent, 1);
}
}
/**
* Refresh the pages
*/
public async refresh(): Promise<void> {
this.loadPages();
this.wizardView?.refresh();
}
}

View File

@@ -0,0 +1,54 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { CurrentModelsPage } from './currentModelsPage';
import { ModelViewBase } from './modelViewBase';
import * as constants from '../../common/constants';
import { ApiWrapper } from '../../common/apiWrapper';
import { DialogView } from '../dialogView';
/**
* Dialog to render registered model views
*/
export class RegisteredModelsDialog extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string) {
super(apiWrapper, root);
this.dialogView = new DialogView(this._apiWrapper);
}
public dialogView: DialogView;
public currentLanguagesTab: CurrentModelsPage | undefined;
/**
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this);
let dialog = this.dialogView.createDialog('', [this.currentLanguagesTab]);
this.mainViewPanel = dialog;
dialog.okButton.hidden = true;
dialog.cancelButton.label = constants.extLangDoneButtonText;
dialog.registerCloseValidator(() => {
return false; // Blocks Enter key from closing dialog.
});
this._apiWrapper.openDialog(dialog);
}
/**
* Resets the tabs for given provider Id
*/
public async refresh(): Promise<void> {
if (this.dialogView) {
this.dialogView.refresh();
}
}
}