ML extension - Improving predict parameter mapping experience (#10264)

This commit is contained in:
Leila Lali
2020-05-10 18:10:17 -07:00
committed by GitHub
parent f6e7b56946
commit 3d2d791f18
44 changed files with 782 additions and 388 deletions

View File

@@ -115,7 +115,8 @@ const modelParameters: ModelParameters = {
};
describe('Model Controller', () => {
it('Should open deploy model wizard successfully ', async function (): Promise<void> {
it('Should open import model wizard successfully ', async function (): Promise<void> {
let testContext = createContext();
@@ -125,16 +126,24 @@ describe('Model Controller', () => {
tableName: 'table',
schema: 'dbo'
}));
testContext.deployModelService.setup(x => x.storeRecentImportTable(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
testContext.deployModelService.setup(x => x.verifyConfigTable(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
testContext.deployModelService.setup(x => x.deployLocalModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.deployModelService.setup(x => x.deleteModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('path'));
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames));
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames));
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
testContext.azureModelService.setup(x => x.signInToAzure()).returns(() => Promise.resolve());
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('path'));
const view = await controller.registerModel(undefined);
const view = await controller.importModel(undefined);
should.notEqual(view, undefined);
});
@@ -161,7 +170,10 @@ describe('Model Controller', () => {
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames));
testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(true));
testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters));
testContext.deployModelService.setup(x => x.verifyConfigTable(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
testContext.deployModelService.setup(x => x.installDependencies()).returns(() => Promise.resolve());
testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
@@ -169,6 +181,17 @@ describe('Model Controller', () => {
should.notEqual(view, undefined);
});
it('Should show error message if onnx is not supported ', async function (): Promise<void> {
let testContext = createContext();
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(false));
testContext.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
const view = await controller.predictModel();
should.equal(view, undefined);
testContext.apiWrapper.verify(x => x.showErrorMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Should open edit model dialog successfully ', async function (): Promise<void> {
let testContext = createContext();
testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
@@ -199,4 +222,5 @@ describe('Model Controller', () => {
should.notEqual(view, undefined);
});
});

View File

@@ -71,21 +71,21 @@ describe('Azure Models Component', () => {
name: 'model'
}
];
parent.on(ListAccountsEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
parent.on(ListAccountsEventName, (args) => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { inputArgs: args, data: accounts });
});
parent.on(ListSubscriptionsEventName, () => {
parent.on(ListSubscriptionsEventName, (args) => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { inputArgs: args, data: subscriptions });
});
parent.on(ListGroupsEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
parent.on(ListGroupsEventName, (args) => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { inputArgs: args, data: groups });
});
parent.on(ListWorkspacesEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
parent.on(ListWorkspacesEventName, (args) => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { inputArgs: args, data: workspaces });
});
parent.on(ListAzureModelsEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
parent.on(ListAzureModelsEventName, (args) => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { inputArgs: args, data: models });
});
await view.refresh();
testContext.onClick.fire(true);

View File

@@ -9,7 +9,7 @@ import 'mocha';
import { createContext } from './utils';
import {
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType, VerifyImportTableEventName
}
from '../../../views/models/modelViewBase';
import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces';
@@ -136,42 +136,45 @@ describe('Predict Wizard', () => {
]
};
view.on(ListModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
view.on(ListModelsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { inputArgs: args, data: localModels });
});
view.on(ListAccountsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
view.on(ListAccountsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { inputArgs: args, data: accounts });
});
view.on(ListSubscriptionsEventName, () => {
view.on(ListSubscriptionsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { inputArgs: args, data: subscriptions });
});
view.on(ListGroupsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
view.on(ListGroupsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { inputArgs: args, data: groups });
});
view.on(ListWorkspacesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
view.on(ListWorkspacesEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { inputArgs: args, data: workspaces });
});
view.on(ListAzureModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
view.on(ListAzureModelsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { inputArgs: args, data: models });
});
view.on(ListDatabaseNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames });
view.on(ListDatabaseNamesEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { inputArgs: args, data: dbNames });
});
view.on(ListTableNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames });
view.on(ListTableNamesEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { inputArgs: args, data: tableNames });
});
view.on(ListColumnNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames });
view.on(ListColumnNamesEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { inputArgs: args, data: columnNames });
});
view.on(LoadModelParametersEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters });
view.on(LoadModelParametersEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { inputArgs: args, data: modelParameters });
});
view.on(DownloadAzureModelEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' });
view.on(DownloadAzureModelEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { inputArgs: args, data: 'path' });
});
view.on(DownloadRegisteredModelEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
view.on(DownloadRegisteredModelEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { inputArgs: args, data: 'path' });
});
view.on(VerifyImportTableEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(VerifyImportTableEventName), { inputArgs: args, data: view.importTable });
});
if (view.modelBrowsePage) {
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;

View File

@@ -7,24 +7,78 @@ import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase';
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName, VerifyImportTableEventName } from '../../../views/models/modelViewBase';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard';
import { DatabaseTable } from '../../../prediction/interfaces';
let accounts: azdata.Account[] = [
{
key: {
accountId: '1',
providerId: ''
},
displayInfo: {
displayName: 'account',
userId: '',
accountType: '',
contextualDisplayName: ''
},
isStale: false,
properties: []
}
];
let subscriptions: azureResource.AzureResourceSubscription[] = [
{
name: 'subscription',
id: '2'
}
];
let groups: azureResource.AzureResourceResourceGroup[] = [
{
name: 'group',
id: '3'
}
];
let workspaces: Workspace[] = [
{
name: 'workspace',
id: '4'
}
];
let models: WorkspaceModel[] = [
{
id: '5',
name: 'model'
}
];
let localModels: ImportedModel[] = [
{
id: 1,
modelName: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
let importTable: DatabaseTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
describe('Register Model Wizard', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'table',
schema: 'dbo'
};
view.importTable = importTable;
await view.open();
should.notEqual(view.wizardView, undefined);
should.notEqual(view.modelSourcePage, undefined);
@@ -34,98 +88,56 @@ describe('Register Model Wizard', () => {
let testContext = createContext();
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
view.importTable = importTable;
await view.open();
let accounts: azdata.Account[] = [
{
key: {
accountId: '1',
providerId: ''
},
displayInfo: {
displayName: 'account',
userId: '',
accountType: '',
contextualDisplayName: ''
},
isStale: false,
properties: []
}
];
let subscriptions: azureResource.AzureResourceSubscription[] = [
{
name: 'subscription',
id: '2'
}
];
let groups: azureResource.AzureResourceResourceGroup[] = [
{
name: 'group',
id: '3'
}
];
let workspaces: Workspace[] = [
{
name: 'workspace',
id: '4'
}
];
let models: WorkspaceModel[] = [
{
id: '5',
name: 'model'
}
];
let localModels: ImportedModel[] = [
{
id: 1,
modelName: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
view.on(ListModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
});
view.on(ListDatabaseNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [
'db', 'db1'
] });
});
view.on(ListTableNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [
'tb', 'tb1'
] });
});
view.on(ListAccountsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
});
view.on(ListSubscriptionsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
});
view.on(ListGroupsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
});
view.on(ListWorkspacesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
});
view.on(ListAzureModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
});
setEvents(view);
await view.refresh();
should.notEqual(view.modelBrowsePage, undefined);
if (view.modelBrowsePage) {
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
await view.modelBrowsePage.refresh();
should.equal(view.modelBrowsePage.modelSourceType, ModelSourceType.Azure);
}
await view.refresh();
should.notEqual(view.azureModelsComponent?.data ,undefined);
should.notEqual(view.azureModelsComponent?.data, undefined);
should.notEqual(view.localModelsComponent?.data, undefined);
});
function setEvents(view: ImportModelWizard): void {
view.on(ListModelsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { inputArgs: args, data: localModels });
});
view.on(ListDatabaseNamesEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), {
inputArgs: args, data: [
'db', 'db1'
]
});
});
view.on(ListTableNamesEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), {
inputArgs: args, data: [
'tb', 'tb1'
]
});
});
view.on(ListAccountsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { inputArgs: args, data: accounts });
});
view.on(ListSubscriptionsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { inputArgs: args, data: subscriptions });
});
view.on(ListGroupsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { inputArgs: args, data: groups });
});
view.on(ListWorkspacesEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { inputArgs: args, data: workspaces });
});
view.on(ListAzureModelsEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { inputArgs: args, data: models });
});
view.on(VerifyImportTableEventName, (args) => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(VerifyImportTableEventName), { inputArgs: args, data: view.importTable });
});
}
});