Files
azuredatastudio/extensions/machine-learning/src/test/views/models/predictWizard.test.ts
2022-04-14 11:06:45 -07:00

212 lines
7.4 KiB
TypeScript

/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import {
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType, VerifyImportTableEventName
}
from '../../../views/models/modelViewBase';
import { ImportedModel, ModelParameters, WorkspaceModel } from '../../../modelManagement/interfaces';
import { azureResource } from 'azurecore';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase';
import { PredictWizard } from '../../../views/models/prediction/predictWizard';
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
describe('Predict Wizard', () => {
// Smoke test: opening a freshly constructed PredictWizard must leave both the
// wizard view and the model source page defined.
it('Should create view components successfully ', async function (): Promise<void> {
	const context = createContext();
	const wizard = new PredictWizard(context.apiWrapper.object, '');

	await wizard.open();

	// Both components are expected to exist once open() has resolved.
	should.notEqual(wizard.wizardView, undefined);
	should.notEqual(wizard.modelSourcePage, undefined);
});
// Data-loading test: every event the wizard raises is answered with canned
// fixture data, then the Azure and registered-model sources are refreshed in
// turn and the resulting component data is checked.
it('Should load data successfully ', async function (): Promise<void> {
	const context = createContext();
	const wizard = new PredictWizard(context.apiWrapper.object, '');
	wizard.importTable = { databaseName: 'db', tableName: 'tb', schema: 'dbo' };
	await wizard.open();

	// ---- Fixtures returned by the fake event responders below ----
	const accountList: azdata.Account[] = [{
		key: { accountId: '1', providerId: '' },
		displayInfo: {
			displayName: 'account',
			userId: '',
			accountType: '',
			contextualDisplayName: ''
		},
		isStale: false,
		properties: []
	}];
	const subscriptionList: azureResource.AzureResourceSubscription[] = [
		{ name: 'subscription', id: '2' }
	];
	const groupList: azureResource.AzureResourceResourceGroup[] = [
		{ name: 'group', id: '3', subscription: { id: 's1', name: 's1' } }
	];
	const workspaceList: Workspace[] = [
		{ name: 'workspace', id: '4' }
	];
	const azureModels: WorkspaceModel[] = [
		{ id: '5', name: 'model' }
	];
	const registeredModels: ImportedModel[] = [{
		id: 1,
		modelName: 'model',
		table: { databaseName: 'db', tableName: 'tb', schema: 'dbo' }
	}];
	const databases: string[] = ['db1', 'db2'];
	const tables: DatabaseTable[] = [
		{ databaseName: 'db1', schema: 'dbo', tableName: 'tb1' },
		{ databaseName: 'db1', tableName: 'tb2', schema: 'dbo' }
	];
	const columns: TableColumn[] = [
		{ columnName: 'c1', dataType: 'int' },
		{ columnName: 'c2', dataType: 'varchar' }
	];
	const parameters: ModelParameters = {
		inputs: [
			{ name: 'p1', type: 'int' },
			{ name: 'p2', type: 'varchar' }
		],
		outputs: [
			{ name: 'o1', type: 'int' }
		]
	};

	// ---- Fake responders ----
	// Each entry pairs an event name with a producer for the payload to send back.
	// Producers are invoked only when the event fires, so `wizard.importTable`
	// is read at that moment rather than at registration time.
	const cannedResponses: Array<[string, () => unknown]> = [
		[ListModelsEventName, () => registeredModels],
		[ListAccountsEventName, () => accountList],
		[ListSubscriptionsEventName, () => subscriptionList],
		[ListGroupsEventName, () => groupList],
		[ListWorkspacesEventName, () => workspaceList],
		[ListAzureModelsEventName, () => azureModels],
		[ListDatabaseNamesEventName, () => databases],
		[ListTableNamesEventName, () => tables],
		[ListColumnNamesEventName, () => columns],
		[LoadModelParametersEventName, () => parameters],
		[DownloadAzureModelEventName, () => 'path'],
		[DownloadRegisteredModelEventName, () => 'path'],
		[VerifyImportTableEventName, () => wizard.importTable]
	];
	// Listeners are registered in table order; each one replies on the event's
	// callback channel, echoing the incoming args as `inputArgs`.
	for (const [eventName, resolveData] of cannedResponses) {
		wizard.on(eventName, (args) => {
			wizard.sendCallbackRequest(ViewBase.getCallbackEventName(eventName), { inputArgs: args, data: resolveData() });
		});
	}

	// ---- Azure model source ----
	if (wizard.modelBrowsePage) {
		wizard.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
	}
	await wizard.refresh();
	should.notEqual(wizard.azureModelsComponent?.data, undefined, 'Data from Azure component should not be null');

	// ---- Registered model source ----
	if (wizard.modelBrowsePage) {
		wizard.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
	}
	await wizard.refresh();
	// Fires the click emitter exposed by the test context.
	// NOTE(review): presumably simulates a UI click wired up in createContext — confirm.
	context.onClick.fire(undefined);
	should.equal(wizard.modelSourcePage?.data, ModelSourceType.RegisteredModels, 'Model source should be registered models');
	should.notEqual(wizard.localModelsComponent?.data, undefined, 'Data from local model component should not be null');
	should.notEqual(wizard.modelBrowsePage?.registeredModelsComponent?.data, undefined, 'Data from registered model component should not be null');
	const registeredData = wizard.modelBrowsePage?.registeredModelsComponent?.data;
	if (registeredData) {
		should.equal(registeredData.length, 1, 'Data from registered model component should not be empty');
	}

	const modelFileName = await wizard.getModelFileName();
	should.notEqual(modelFileName, undefined, 'Model file name should not be null');

	// ---- Column selection page ----
	await wizard.columnsSelectionPage?.onEnter();
	await wizard.columnsSelectionPage?.inputColumnsComponent?.loadWithTable(tables[0]);
	should.notEqual(wizard.columnsSelectionPage?.data, undefined, 'Data from column selection component should not be null');

	// The number of mapped columns must match the model's declared parameters.
	const inputCount = wizard.columnsSelectionPage?.data?.inputColumns?.length;
	should.equal(inputCount, parameters.inputs.length, `unexpected number of inputs. ${inputCount}`);
	const outputCount = wizard.columnsSelectionPage?.data?.outputColumns?.length;
	should.equal(outputCount, parameters.outputs.length, `unexpected number of outputs. ${outputCount}`);
});
});