ML - Target import table selectable by user (#10071)

ML - Target import table selectable by user
This commit is contained in:
Leila Lali
2020-04-21 08:02:48 -07:00
committed by GitHub
parent 4f1d4276a0
commit a34feb4448
30 changed files with 1172 additions and 317 deletions

View File

@@ -17,6 +17,8 @@ import * as path from 'path';
import * as os from 'os';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as fs from 'fs';
import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent';
import { DatabaseTable } from '../../prediction/interfaces';
interface TestContext {
@@ -24,6 +26,8 @@ interface TestContext {
config: TypeMoq.IMock<Config>;
queryRunner: TypeMoq.IMock<QueryRunner>;
modelClient: TypeMoq.IMock<ModelPythonClient>;
recentModels: TypeMoq.IMock<ModelConfigRecent>;
importTable: DatabaseTable;
}
function createContext(): TestContext {
@@ -32,7 +36,13 @@ function createContext(): TestContext {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
modelClient: TypeMoq.Mock.ofType(ModelPythonClient),
recentModels: TypeMoq.Mock.ofType(ModelConfigRecent),
importTable: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
};
}
@@ -40,14 +50,20 @@ describe('DeployedModelService', () => {
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
const testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
let importTable: DatabaseTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
await should(service.getDeployedModels()).rejected();
testContext.modelClient.object,
testContext.recentModels.object);
await should(service.getDeployedModels(importTable)).rejected();
});
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
@@ -61,7 +77,9 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
}
];
const result = {
@@ -106,12 +124,13 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
const actual = await service.getDeployedModels();
const actual = await service.getDeployedModels(testContext.importTable);
should.deepEqual(actual, expected);
});
@@ -140,7 +159,8 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
const actual = await service.loadModelParameters('');
should.deepEqual(actual, expected);
});
@@ -158,7 +178,8 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
const result = {
rowCount: 1,
@@ -177,7 +198,8 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
@@ -198,7 +220,8 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
const row = [
{
@@ -247,15 +270,17 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => {
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0), TypeMoq.It.isAny())).returns(() => {
deployed = true;
return Promise.resolve(result);
});
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
});
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
@@ -264,7 +289,7 @@ describe('DeployedModelService', () => {
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, 'test');
await should(service.deployLocalModel(tempFilePath, model)).resolved();
await should(service.deployLocalModel(tempFilePath, model, testContext.importTable)).resolved();
}
finally {
await utils.deleteFile(tempFilePath);
@@ -273,31 +298,28 @@ describe('DeployedModelService', () => {
it('getConfigureQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const dbName = 'curre[n]tDb';
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo';
const expected = `
IF NOT EXISTS (
SELECT [name]
FROM sys.databases
WHERE [name] = N'd[]b'
)
CREATE DATABASE [d[[]]b]
GO
USE [d[[]]b]
IF EXISTS
( SELECT [t.name], [s.name]
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE [t.name] = 'ta[b]le'
AND [s.name] = 'dbo'
WHERE t.name = 'ta[b]le'
AND s.name = 'dbo'
)
BEGIN
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_name')
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_name] [varchar](256) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_content')
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_content] [varbinary](max) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
@@ -315,23 +337,22 @@ describe('DeployedModelService', () => {
CREATE TABLE [dbo].[ta[[b]]le](
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
[artifact_name] [varchar](256) NOT NULL,
[group_path] [varchar](256) NULL,
[artifact_content] [varbinary](max) NOT NULL,
[artifact_initial_size] [bigint] NULL,
[name] [varchar](256) NULL,
[version] [varchar](256) NULL,
[created] [datetime] NULL,
[description] [varchar](256) NULL,
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
CONSTRAINT [ta[[b]]le_artifact_pk] PRIMARY KEY CLUSTERED
(
[artifact_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
END
`;
const actual = service.getConfigureQuery(dbName);
should.equal(actual.indexOf(expected) > 0, true);
const actual = service.getConfigureTableQuery(testContext.importTable);
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
});
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
@@ -340,23 +361,23 @@ describe('DeployedModelService', () => {
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo';
const expected = `
SELECT artifact_id, artifact_name, name, description, version, created
FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
const actual = service.getDeployedModelsQuery();
const actual = service.getDeployedModelsQuery(testContext.importTable);
should.deepEqual(expected, actual);
});
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const dbName = 'curre[n]tDb';
const model: RegisteredModel =
{
id: 1,
@@ -364,28 +385,27 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
const expected = `
Insert into [dbo].[ta[[b]]le]
(artifact_name, group_path, artifact_content, name, version, description)
Insert into [dbo].[tb]
(artifact_name, artifact_content, name, version, description)
values (
'name1',
'ADS',
,
'title1',
'1.1',
'desc1')`;
const actual = service.getInsertModelQuery(dbName, model);
const actual = service.getInsertModelQuery(model, testContext.importTable);
should.equal(actual.indexOf(expected) > 0, true);
});
@@ -398,17 +418,19 @@ describe('DeployedModelService', () => {
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
version: '1.1',
table: testContext.importTable
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
testContext.modelClient.object,
testContext.recentModels.object);
model.table = {
databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo'
};
const expected = `
SELECT artifact_content
FROM [d[[]]b].[dbo].[ta[[b]]le]