ML - Target import table selectable by user (#10071)

ML - Target import table selectable by user
This commit is contained in:
Leila Lali
2020-04-21 08:02:48 -07:00
committed by GitHub
parent 4f1d4276a0
commit a34feb4448
30 changed files with 1172 additions and 317 deletions

View File

@@ -17,6 +17,8 @@ import * as path from 'path';
import * as os from 'os';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as fs from 'fs';
import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent';
import { DatabaseTable } from '../../prediction/interfaces';
interface TestContext {
@@ -24,6 +26,8 @@ interface TestContext {
config: TypeMoq.IMock<Config>;
queryRunner: TypeMoq.IMock<QueryRunner>;
modelClient: TypeMoq.IMock<ModelPythonClient>;
recentModels: TypeMoq.IMock<ModelConfigRecent>;
importTable: DatabaseTable;
}
function createContext(): TestContext {
@@ -32,7 +36,13 @@ function createContext(): TestContext {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
modelClient: TypeMoq.Mock.ofType(ModelPythonClient),
recentModels: TypeMoq.Mock.ofType(ModelConfigRecent),
importTable: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
};
}
@@ -40,14 +50,20 @@ describe('DeployedModelService', () => {
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
const testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
let importTable: DatabaseTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
await should(service.getDeployedModels()).rejected();
testContext.modelClient.object,
testContext.recentModels.object);
await should(service.getDeployedModels(importTable)).rejected();
});
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
@@ -61,7 +77,9 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
}
];
const result = {
@@ -106,12 +124,13 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
const actual = await service.getDeployedModels();
const actual = await service.getDeployedModels(testContext.importTable);
should.deepEqual(actual, expected);
});
@@ -140,7 +159,8 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
const actual = await service.loadModelParameters('');
should.deepEqual(actual, expected);
});
@@ -158,7 +178,8 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
const result = {
rowCount: 1,
@@ -177,7 +198,8 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
@@ -198,7 +220,8 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
const row = [
{
@@ -247,15 +270,17 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => {
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0), TypeMoq.It.isAny())).returns(() => {
deployed = true;
return Promise.resolve(result);
});
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
});
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
@@ -264,7 +289,7 @@ describe('DeployedModelService', () => {
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, 'test');
await should(service.deployLocalModel(tempFilePath, model)).resolved();
await should(service.deployLocalModel(tempFilePath, model, testContext.importTable)).resolved();
}
finally {
await utils.deleteFile(tempFilePath);
@@ -273,31 +298,28 @@ describe('DeployedModelService', () => {
it('getConfigureQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const dbName = 'curre[n]tDb';
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo';
const expected = `
IF NOT EXISTS (
SELECT [name]
FROM sys.databases
WHERE [name] = N'd[]b'
)
CREATE DATABASE [d[[]]b]
GO
USE [d[[]]b]
IF EXISTS
( SELECT [t.name], [s.name]
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE [t.name] = 'ta[b]le'
AND [s.name] = 'dbo'
WHERE t.name = 'ta[b]le'
AND s.name = 'dbo'
)
BEGIN
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_name')
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_name] [varchar](256) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_content')
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_content] [varbinary](max) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
@@ -315,23 +337,22 @@ describe('DeployedModelService', () => {
CREATE TABLE [dbo].[ta[[b]]le](
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
[artifact_name] [varchar](256) NOT NULL,
[group_path] [varchar](256) NULL,
[artifact_content] [varbinary](max) NOT NULL,
[artifact_initial_size] [bigint] NULL,
[name] [varchar](256) NULL,
[version] [varchar](256) NULL,
[created] [datetime] NULL,
[description] [varchar](256) NULL,
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
CONSTRAINT [ta[[b]]le_artifact_pk] PRIMARY KEY CLUSTERED
(
[artifact_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
END
`;
const actual = service.getConfigureQuery(dbName);
should.equal(actual.indexOf(expected) > 0, true);
const actual = service.getConfigureTableQuery(testContext.importTable);
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
});
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
@@ -340,23 +361,23 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo';
const expected = `
SELECT artifact_id, artifact_name, name, description, version, created
FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
const actual = service.getDeployedModelsQuery();
const actual = service.getDeployedModelsQuery(testContext.importTable);
should.deepEqual(expected, actual);
});
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const dbName = 'curre[n]tDb';
const model: RegisteredModel =
{
id: 1,
@@ -364,28 +385,27 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
const expected = `
Insert into [dbo].[ta[[b]]le]
(artifact_name, group_path, artifact_content, name, version, description)
Insert into [dbo].[tb]
(artifact_name, artifact_content, name, version, description)
values (
'name1',
'ADS',
,
'title1',
'1.1',
'desc1')`;
const actual = service.getInsertModelQuery(dbName, model);
const actual = service.getInsertModelQuery(model, testContext.importTable);
should.equal(actual.indexOf(expected) > 0, true);
});
@@ -398,17 +418,19 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
model.table = {
databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo'
};
const expected = `
SELECT artifact_content
FROM [d[[]]b].[dbo].[ta[[b]]le]

View File

@@ -8,7 +8,6 @@ import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { Config } from '../../configurations/config';
import { PredictService } from '../../prediction/predictService';
import { QueryRunner } from '../../common/queryRunner';
import { RegisteredModel } from '../../modelManagement/interfaces';
@@ -22,7 +21,7 @@ import * as fs from 'fs';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
config: TypeMoq.IMock<Config>;
importTable: DatabaseTable;
queryRunner: TypeMoq.IMock<QueryRunner>;
}
@@ -30,7 +29,11 @@ function createContext(): TestContext {
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
importTable: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
},
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
};
}
@@ -49,8 +52,7 @@ describe('PredictService', () => {
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object,
testContext.config.object);
testContext.queryRunner.object);
const actual = await service.getDatabaseList();
should.deepEqual(actual, expected);
});
@@ -102,8 +104,7 @@ describe('PredictService', () => {
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object,
testContext.config.object);
testContext.queryRunner.object);
const actual = await service.getTableList('db1');
should.deepEqual(actual, expected);
});
@@ -160,8 +161,7 @@ describe('PredictService', () => {
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object,
testContext.config.object);
testContext.queryRunner.object);
const actual = await service.getTableColumnsList(table);
should.deepEqual(actual, expected);
});
@@ -201,13 +201,13 @@ describe('PredictService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object,
testContext.config.object);
testContext.queryRunner.object);
const document: vscode.TextDocument = {
uri: vscode.Uri.parse('file:///usr/home'),
@@ -270,8 +270,7 @@ describe('PredictService', () => {
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object,
testContext.config.object);
testContext.queryRunner.object);
const document: vscode.TextDocument = {
uri: vscode.Uri.parse('file:///usr/home'),

View File

@@ -34,6 +34,6 @@ describe('Dashboard widget', () => {
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '');
dashboard.register();
testContext.onClick.fire();
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atMostOnce());
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce());
});
});

View File

@@ -0,0 +1,170 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import * as TypeMoq from 'typemoq';
import 'mocha';
import { createContext } from './utils';
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { ModelManagementController } from '../../../views/models/modelManagementController';
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
const accounts: azdata.Account[] = [
{
key: {
accountId: '1',
providerId: ''
},
displayInfo: {
displayName: 'account',
userId: '',
accountType: '',
contextualDisplayName: ''
},
isStale: false,
properties: []
}
];
const subscriptions: azureResource.AzureResourceSubscription[] = [
{
name: 'subscription',
id: '2'
}
];
const groups: azureResource.AzureResourceResourceGroup[] = [
{
name: 'group',
id: '3'
}
];
const workspaces: Workspace[] = [
{
name: 'workspace',
id: '4'
}
];
const models: WorkspaceModel[] = [
{
id: '5',
name: 'model'
}
];
const localModels: RegisteredModel[] = [
{
id: 1,
artifactName: 'model',
title: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
const dbNames: string[] = [
'db1',
'db2'
];
const tableNames: DatabaseTable[] = [
{
databaseName: 'db1',
schema: 'dbo',
tableName: 'tb1'
},
{
databaseName: 'db1',
tableName: 'tb2',
schema: 'dbo'
}
];
const columnNames: TableColumn[] = [
{
columnName: 'c1',
dataType: 'int'
},
{
columnName: 'c2',
dataType: 'varchar'
}
];
const modelParameters: ModelParameters = {
inputs: [
{
'name': 'p1',
'type': 'int'
},
{
'name': 'p2',
'type': 'varchar'
}
],
outputs: [
{
'name': 'o1',
'type': 'int'
}
]
};
describe('Model Controller', () => {
it('Should open deploy model wizard successfully ', async function (): Promise<void> {
let testContext = createContext();
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
databaseName: 'db',
tableName: 'table',
schema: 'dbo'
}));
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames));
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames));
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
const view = await controller.registerModel(undefined);
should.notEqual(view, undefined);
});
it('Should open predict wizard successfully ', async function (): Promise<void> {
let testContext = createContext();
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
databaseName: 'db',
tableName: 'table',
schema: 'dbo'
}));
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve([
'db', 'db1'
]));
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve([
{ tableName: 'tb', databaseName: 'db', schema: 'dbo' }
]));
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames));
testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters));
testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
const view = await controller.predictModel();
should.notEqual(view, undefined);
});
});

View File

@@ -34,6 +34,11 @@ describe('Predict Wizard', () => {
let testContext = createContext();
let view = new PredictWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
await view.open();
let accounts: azdata.Account[] = [
{
@@ -79,7 +84,12 @@ describe('Predict Wizard', () => {
{
id: 1,
artifactName: 'model',
title: 'model'
title: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
const dbNames: string[] = [

View File

@@ -7,21 +7,25 @@ import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType } from '../../../views/models/modelViewBase';
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase';
import { RegisteredModel } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard';
import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard';
describe('Register Model Wizard', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'table',
schema: 'dbo'
};
await view.open();
await view.refresh();
should.notEqual(view.wizardView, undefined);
should.notEqual(view.modelSourcePage, undefined);
});
@@ -29,7 +33,12 @@ describe('Register Model Wizard', () => {
it('Should load data successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
await view.open();
let accounts: azdata.Account[] = [
{
@@ -75,12 +84,27 @@ describe('Register Model Wizard', () => {
{
id: 1,
artifactName: 'model',
title: 'model'
title: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
view.on(ListModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
});
view.on(ListDatabaseNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [
'db', 'db1'
] });
});
view.on(ListTableNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [
'tb', 'tb1'
] });
});
view.on(ListAccountsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
});

View File

@@ -6,7 +6,7 @@
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog';
import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog';
import { ListModelsEventName } from '../../../views/models/modelViewBase';
import { RegisteredModel } from '../../../modelManagement/interfaces';
import { ViewBase } from '../../../views/viewBase';
@@ -15,7 +15,7 @@ describe('Registered Models Dialog', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new RegisteredModelsDialog(testContext.apiWrapper.object, '');
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
view.open();
should.notEqual(view.dialogView, undefined);
@@ -25,13 +25,18 @@ describe('Registered Models Dialog', () => {
it('Should load data successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new RegisteredModelsDialog(testContext.apiWrapper.object, '');
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
view.open();
let models: RegisteredModel[] = [
{
id: 1,
artifactName: 'model',
title: ''
title: '',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
view.on(ListModelsEventName, () => {

View File

@@ -9,11 +9,17 @@ import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../../common/apiWrapper';
import { createViewContext } from '../utils';
import { ModelViewBase } from '../../../views/models/modelViewBase';
import { AzureModelRegistryService } from '../../../modelManagement/azureModelRegistryService';
import { DeployedModelService } from '../../../modelManagement/deployedModelService';
import { PredictService } from '../../../prediction/predictService';
export interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
view: azdata.ModelView;
onClick: vscode.EventEmitter<any>;
azureModelService: TypeMoq.IMock<AzureModelRegistryService>;
deployModelService: TypeMoq.IMock<DeployedModelService>;
predictService: TypeMoq.IMock<PredictService>;
}
export class ParentDialog extends ModelViewBase {
@@ -36,6 +42,9 @@ export function createContext(): TestContext {
return {
apiWrapper: viewTestContext.apiWrapper,
view: viewTestContext.view,
onClick: viewTestContext.onClick
onClick: viewTestContext.onClick,
azureModelService: TypeMoq.Mock.ofType(AzureModelRegistryService),
deployModelService: TypeMoq.Mock.ofType(DeployedModelService),
predictService: TypeMoq.Mock.ofType(PredictService)
};
}

View File

@@ -31,6 +31,11 @@ export function createViewContext(): ViewTestContext {
let button: azdata.ButtonComponent = Object.assign({}, componentBase, {
onDidClick: onClick.event
});
let link: azdata.HyperlinkComponent = Object.assign({}, componentBase, {
onDidClick: onClick.event,
label: '',
url: ''
});
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
checked: true,
onDidClick: onClick.event
@@ -61,6 +66,11 @@ export function createViewContext(): ViewTestContext {
withProperties: () => buttonBuilder,
withValidation: () => buttonBuilder
};
let hyperLinkBuilder: azdata.ComponentBuilder<azdata.HyperlinkComponent> = {
component: () => link,
withProperties: () => hyperLinkBuilder,
withValidation: () => hyperLinkBuilder
};
let radioButtonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
component: () => radioButton,
withProperties: () => radioButtonBuilder,
@@ -72,7 +82,7 @@ export function createViewContext(): ViewTestContext {
withValidation: () => checkBoxBuilder
};
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
onTextChanged: undefined!,
onTextChanged: onClick.event!,
onEnterKeyPressed: undefined!,
value: ''
});
@@ -216,7 +226,7 @@ export function createViewContext(): ViewTestContext {
toolbarContainer: undefined!,
loadingComponent: () => loadingBuilder,
fileBrowserTree: undefined!,
hyperlink: undefined!,
hyperlink: () => hyperLinkBuilder,
tabbedPanel: undefined!,
separator: undefined!
}