mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-07 01:25:38 -05:00
ML - Target import table selectable by user (#10071)
ML - Target import table selectable by user
This commit is contained in:
@@ -17,6 +17,8 @@ import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as fs from 'fs';
|
||||
import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent';
|
||||
import { DatabaseTable } from '../../prediction/interfaces';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
@@ -24,6 +26,8 @@ interface TestContext {
|
||||
config: TypeMoq.IMock<Config>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
modelClient: TypeMoq.IMock<ModelPythonClient>;
|
||||
recentModels: TypeMoq.IMock<ModelConfigRecent>;
|
||||
importTable: DatabaseTable;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
@@ -32,7 +36,13 @@ function createContext(): TestContext {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
|
||||
modelClient: TypeMoq.Mock.ofType(ModelPythonClient),
|
||||
recentModels: TypeMoq.Mock.ofType(ModelConfigRecent),
|
||||
importTable: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -40,14 +50,20 @@ describe('DeployedModelService', () => {
|
||||
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
let importTable: DatabaseTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
};
|
||||
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
await should(service.getDeployedModels()).rejected();
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
await should(service.getDeployedModels(importTable)).rejected();
|
||||
});
|
||||
|
||||
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
|
||||
@@ -61,7 +77,9 @@ describe('DeployedModelService', () => {
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
|
||||
}
|
||||
];
|
||||
const result = {
|
||||
@@ -106,12 +124,13 @@ describe('DeployedModelService', () => {
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
const actual = await service.getDeployedModels();
|
||||
const actual = await service.getDeployedModels(testContext.importTable);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
@@ -140,7 +159,8 @@ describe('DeployedModelService', () => {
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
const actual = await service.loadModelParameters('');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
@@ -158,7 +178,8 @@ describe('DeployedModelService', () => {
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
@@ -177,7 +198,8 @@ describe('DeployedModelService', () => {
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
@@ -198,7 +220,8 @@ describe('DeployedModelService', () => {
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
const row = [
|
||||
{
|
||||
@@ -247,15 +270,17 @@ describe('DeployedModelService', () => {
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => {
|
||||
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0), TypeMoq.It.isAny())).returns(() => {
|
||||
deployed = true;
|
||||
return Promise.resolve(result);
|
||||
});
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
||||
});
|
||||
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
@@ -264,7 +289,7 @@ describe('DeployedModelService', () => {
|
||||
try {
|
||||
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
await should(service.deployLocalModel(tempFilePath, model)).resolved();
|
||||
await should(service.deployLocalModel(tempFilePath, model, testContext.importTable)).resolved();
|
||||
}
|
||||
finally {
|
||||
await utils.deleteFile(tempFilePath);
|
||||
@@ -273,31 +298,28 @@ describe('DeployedModelService', () => {
|
||||
|
||||
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const dbName = 'curre[n]tDb';
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
|
||||
testContext.importTable.databaseName = 'd[]b';
|
||||
testContext.importTable.tableName = 'ta[b]le';
|
||||
testContext.importTable.schema = 'dbo';
|
||||
const expected = `
|
||||
IF NOT EXISTS (
|
||||
SELECT [name]
|
||||
FROM sys.databases
|
||||
WHERE [name] = N'd[]b'
|
||||
)
|
||||
CREATE DATABASE [d[[]]b]
|
||||
GO
|
||||
USE [d[[]]b]
|
||||
IF EXISTS
|
||||
( SELECT [t.name], [s.name]
|
||||
( SELECT t.name, s.name
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE [t.name] = 'ta[b]le'
|
||||
AND [s.name] = 'dbo'
|
||||
WHERE t.name = 'ta[b]le'
|
||||
AND s.name = 'dbo'
|
||||
)
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_name')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_name] [varchar](256) NOT NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_content')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_content] [varbinary](max) NOT NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
|
||||
@@ -315,23 +337,22 @@ describe('DeployedModelService', () => {
|
||||
CREATE TABLE [dbo].[ta[[b]]le](
|
||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[artifact_name] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NULL,
|
||||
[artifact_content] [varbinary](max) NOT NULL,
|
||||
[artifact_initial_size] [bigint] NULL,
|
||||
[name] [varchar](256) NULL,
|
||||
[version] [varchar](256) NULL,
|
||||
[created] [datetime] NULL,
|
||||
[description] [varchar](256) NULL,
|
||||
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
||||
CONSTRAINT [ta[[b]]le_artifact_pk] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[artifact_id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||
END
|
||||
`;
|
||||
const actual = service.getConfigureQuery(dbName);
|
||||
should.equal(actual.indexOf(expected) > 0, true);
|
||||
const actual = service.getConfigureTableQuery(testContext.importTable);
|
||||
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
|
||||
});
|
||||
|
||||
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
|
||||
@@ -340,23 +361,23 @@ describe('DeployedModelService', () => {
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
testContext.importTable.databaseName = 'd[]b';
|
||||
testContext.importTable.tableName = 'ta[b]le';
|
||||
testContext.importTable.schema = 'dbo';
|
||||
const expected = `
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||
Order by artifact_id
|
||||
`;
|
||||
const actual = service.getDeployedModelsQuery();
|
||||
const actual = service.getDeployedModelsQuery(testContext.importTable);
|
||||
should.deepEqual(expected, actual);
|
||||
});
|
||||
|
||||
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const dbName = 'curre[n]tDb';
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
@@ -364,28 +385,27 @@ describe('DeployedModelService', () => {
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
|
||||
const expected = `
|
||||
Insert into [dbo].[ta[[b]]le]
|
||||
(artifact_name, group_path, artifact_content, name, version, description)
|
||||
Insert into [dbo].[tb]
|
||||
(artifact_name, artifact_content, name, version, description)
|
||||
values (
|
||||
'name1',
|
||||
'ADS',
|
||||
,
|
||||
'title1',
|
||||
'1.1',
|
||||
'desc1')`;
|
||||
const actual = service.getInsertModelQuery(dbName, model);
|
||||
const actual = service.getInsertModelQuery(model, testContext.importTable);
|
||||
should.equal(actual.indexOf(expected) > 0, true);
|
||||
});
|
||||
|
||||
@@ -398,17 +418,19 @@ describe('DeployedModelService', () => {
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
model.table = {
|
||||
databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo'
|
||||
};
|
||||
const expected = `
|
||||
SELECT artifact_content
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
|
||||
@@ -8,7 +8,6 @@ import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
@@ -22,7 +21,7 @@ import * as fs from 'fs';
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
importTable: DatabaseTable;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
}
|
||||
|
||||
@@ -30,7 +29,11 @@ function createContext(): TestContext {
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
importTable: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
},
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
|
||||
};
|
||||
}
|
||||
@@ -49,8 +52,7 @@ describe('PredictService', () => {
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
testContext.queryRunner.object);
|
||||
const actual = await service.getDatabaseList();
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
@@ -102,8 +104,7 @@ describe('PredictService', () => {
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
testContext.queryRunner.object);
|
||||
const actual = await service.getTableList('db1');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
@@ -160,8 +161,7 @@ describe('PredictService', () => {
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
testContext.queryRunner.object);
|
||||
const actual = await service.getTableColumnsList(table);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
@@ -201,13 +201,13 @@ describe('PredictService', () => {
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
testContext.queryRunner.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
@@ -270,8 +270,7 @@ describe('PredictService', () => {
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
testContext.queryRunner.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
|
||||
@@ -34,6 +34,6 @@ describe('Dashboard widget', () => {
|
||||
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '');
|
||||
dashboard.register();
|
||||
testContext.onClick.fire();
|
||||
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atMostOnce());
|
||||
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce());
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { ModelManagementController } from '../../../views/models/modelManagementController';
|
||||
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||
|
||||
const accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
accountId: '1',
|
||||
providerId: ''
|
||||
},
|
||||
displayInfo: {
|
||||
displayName: 'account',
|
||||
userId: '',
|
||||
accountType: '',
|
||||
contextualDisplayName: ''
|
||||
},
|
||||
isStale: false,
|
||||
properties: []
|
||||
}
|
||||
];
|
||||
const subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||
{
|
||||
name: 'subscription',
|
||||
id: '2'
|
||||
}
|
||||
];
|
||||
const groups: azureResource.AzureResourceResourceGroup[] = [
|
||||
{
|
||||
name: 'group',
|
||||
id: '3'
|
||||
}
|
||||
];
|
||||
const workspaces: Workspace[] = [
|
||||
{
|
||||
name: 'workspace',
|
||||
id: '4'
|
||||
}
|
||||
];
|
||||
const models: WorkspaceModel[] = [
|
||||
{
|
||||
id: '5',
|
||||
name: 'model'
|
||||
}
|
||||
];
|
||||
const localModels: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model',
|
||||
title: 'model',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
const dbNames: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const tableNames: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
const columnNames: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const modelParameters: ModelParameters = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
}
|
||||
]
|
||||
};
|
||||
describe('Model Controller', () => {
|
||||
|
||||
it('Should open deploy model wizard successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
|
||||
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
|
||||
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
|
||||
databaseName: 'db',
|
||||
tableName: 'table',
|
||||
schema: 'dbo'
|
||||
}));
|
||||
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
|
||||
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames));
|
||||
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames));
|
||||
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
|
||||
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
|
||||
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
|
||||
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
|
||||
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
|
||||
|
||||
const view = await controller.registerModel(undefined);
|
||||
should.notEqual(view, undefined);
|
||||
});
|
||||
|
||||
it('Should open predict wizard successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
|
||||
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
|
||||
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
|
||||
databaseName: 'db',
|
||||
tableName: 'table',
|
||||
schema: 'dbo'
|
||||
}));
|
||||
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
|
||||
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve([
|
||||
'db', 'db1'
|
||||
]));
|
||||
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve([
|
||||
{ tableName: 'tb', databaseName: 'db', schema: 'dbo' }
|
||||
]));
|
||||
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
|
||||
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
|
||||
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
|
||||
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
|
||||
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
|
||||
testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames));
|
||||
testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters));
|
||||
testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
|
||||
testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
|
||||
|
||||
const view = await controller.predictModel();
|
||||
should.notEqual(view, undefined);
|
||||
});
|
||||
});
|
||||
@@ -34,6 +34,11 @@ describe('Predict Wizard', () => {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||
view.importTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
};
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
@@ -79,7 +84,12 @@ describe('Predict Wizard', () => {
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model',
|
||||
title: 'model'
|
||||
title: 'model',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
const dbNames: string[] = [
|
||||
|
||||
@@ -7,21 +7,25 @@ import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType } from '../../../views/models/modelViewBase';
|
||||
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard';
|
||||
import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard';
|
||||
|
||||
describe('Register Model Wizard', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
||||
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
|
||||
view.importTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'table',
|
||||
schema: 'dbo'
|
||||
};
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
});
|
||||
@@ -29,7 +33,12 @@ describe('Register Model Wizard', () => {
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
||||
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
|
||||
view.importTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
};
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
@@ -75,12 +84,27 @@ describe('Register Model Wizard', () => {
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model',
|
||||
title: 'model'
|
||||
title: 'model',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [
|
||||
'db', 'db1'
|
||||
] });
|
||||
});
|
||||
view.on(ListTableNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [
|
||||
'tb', 'tb1'
|
||||
] });
|
||||
});
|
||||
view.on(ListAccountsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
||||
});
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog';
|
||||
import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog';
|
||||
import { ListModelsEventName } from '../../../views/models/modelViewBase';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
@@ -15,7 +15,7 @@ describe('Registered Models Dialog', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisteredModelsDialog(testContext.apiWrapper.object, '');
|
||||
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
|
||||
should.notEqual(view.dialogView, undefined);
|
||||
@@ -25,13 +25,18 @@ describe('Registered Models Dialog', () => {
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisteredModelsDialog(testContext.apiWrapper.object, '');
|
||||
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
let models: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model',
|
||||
title: ''
|
||||
title: '',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
|
||||
@@ -9,11 +9,17 @@ import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { createViewContext } from '../utils';
|
||||
import { ModelViewBase } from '../../../views/models/modelViewBase';
|
||||
import { AzureModelRegistryService } from '../../../modelManagement/azureModelRegistryService';
|
||||
import { DeployedModelService } from '../../../modelManagement/deployedModelService';
|
||||
import { PredictService } from '../../../prediction/predictService';
|
||||
|
||||
export interface TestContext {
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
view: azdata.ModelView;
|
||||
onClick: vscode.EventEmitter<any>;
|
||||
azureModelService: TypeMoq.IMock<AzureModelRegistryService>;
|
||||
deployModelService: TypeMoq.IMock<DeployedModelService>;
|
||||
predictService: TypeMoq.IMock<PredictService>;
|
||||
}
|
||||
|
||||
export class ParentDialog extends ModelViewBase {
|
||||
@@ -36,6 +42,9 @@ export function createContext(): TestContext {
|
||||
return {
|
||||
apiWrapper: viewTestContext.apiWrapper,
|
||||
view: viewTestContext.view,
|
||||
onClick: viewTestContext.onClick
|
||||
onClick: viewTestContext.onClick,
|
||||
azureModelService: TypeMoq.Mock.ofType(AzureModelRegistryService),
|
||||
deployModelService: TypeMoq.Mock.ofType(DeployedModelService),
|
||||
predictService: TypeMoq.Mock.ofType(PredictService)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -31,6 +31,11 @@ export function createViewContext(): ViewTestContext {
|
||||
let button: azdata.ButtonComponent = Object.assign({}, componentBase, {
|
||||
onDidClick: onClick.event
|
||||
});
|
||||
let link: azdata.HyperlinkComponent = Object.assign({}, componentBase, {
|
||||
onDidClick: onClick.event,
|
||||
label: '',
|
||||
url: ''
|
||||
});
|
||||
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
|
||||
checked: true,
|
||||
onDidClick: onClick.event
|
||||
@@ -61,6 +66,11 @@ export function createViewContext(): ViewTestContext {
|
||||
withProperties: () => buttonBuilder,
|
||||
withValidation: () => buttonBuilder
|
||||
};
|
||||
let hyperLinkBuilder: azdata.ComponentBuilder<azdata.HyperlinkComponent> = {
|
||||
component: () => link,
|
||||
withProperties: () => hyperLinkBuilder,
|
||||
withValidation: () => hyperLinkBuilder
|
||||
};
|
||||
let radioButtonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
|
||||
component: () => radioButton,
|
||||
withProperties: () => radioButtonBuilder,
|
||||
@@ -72,7 +82,7 @@ export function createViewContext(): ViewTestContext {
|
||||
withValidation: () => checkBoxBuilder
|
||||
};
|
||||
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
|
||||
onTextChanged: undefined!,
|
||||
onTextChanged: onClick.event!,
|
||||
onEnterKeyPressed: undefined!,
|
||||
value: ''
|
||||
});
|
||||
@@ -216,7 +226,7 @@ export function createViewContext(): ViewTestContext {
|
||||
toolbarContainer: undefined!,
|
||||
loadingComponent: () => loadingBuilder,
|
||||
fileBrowserTree: undefined!,
|
||||
hyperlink: undefined!,
|
||||
hyperlink: () => hyperLinkBuilder,
|
||||
tabbedPanel: undefined!,
|
||||
separator: undefined!
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user