Machine Learning Services - Model detection in predict wizard (#9609)

* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
Leila Lali
2020-03-25 13:18:19 -07:00
committed by GitHub
parent 176edde2aa
commit ab82c04766
44 changed files with 2265 additions and 376 deletions

View File

@@ -0,0 +1,232 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Config } from '../../configurations/config';
import { HttpClient } from '../../common/httpClient';
import { azureResource } from '../../typings/azure-resource';
import * as utils from '../utils';
import { Workspace, WorkspacesListByResourceGroupResponse } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel, AssetsQueryByIdResponse, Asset, GetArtifactContentInformation2Response } from '../../modelManagement/interfaces';
import { AzureMachineLearningWorkspaces, Workspaces } from '@azure/arm-machinelearningservices';
import { WorkspaceModels } from '../../modelManagement/workspacesModels';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
config: TypeMoq.IMock<Config>;
httpClient: TypeMoq.IMock<HttpClient>;
outputChannel: vscode.OutputChannel;
op: azdata.BackgroundOperation;
accounts: azdata.Account[];
subscriptions: azureResource.AzureResourceSubscription[];
groups: azureResource.AzureResourceResourceGroup[];
workspaces: Workspace[];
models: WorkspaceModel[];
client: TypeMoq.IMock<AzureMachineLearningWorkspaces>;
workspacesClient: TypeMoq.IMock<Workspaces>;
modelClient: TypeMoq.IMock<WorkspaceModels>;
}
function createContext(): TestContext {
const context = utils.createContext();
const workspaces = TypeMoq.Mock.ofType(Workspaces);
const credentials = {
signRequest: () => {
return Promise.resolve(undefined!!);
}
};
const client = TypeMoq.Mock.ofInstance(new AzureMachineLearningWorkspaces(credentials, 'subscription'));
client.setup(x => x.apiVersion).returns(() => '20180101');
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
httpClient: TypeMoq.Mock.ofType(HttpClient),
outputChannel: context.outputChannel,
op: context.op,
accounts: [
{
key: {
providerId: '',
accountId: 'a1'
},
displayInfo: {
contextualDisplayName: '',
accountType: '',
displayName: 'a1',
userId: 'a1'
},
properties:
{
tenants: [
{
id: '1',
}
]
}
,
isStale: true
}
],
subscriptions: [
{
name: 's1',
id: 's1'
}
],
groups: [
{
name: 'g1',
id: 'g1'
}
],
workspaces: [{
name: 'w1',
id: 'w1'
}
],
models: [
{
name: 'm1',
id: 'm1',
url: 'aml://asset/test.test'
}
],
client: client,
workspacesClient: workspaces,
modelClient: TypeMoq.Mock.ofInstance(new WorkspaceModels(client.object))
};
}
describe('AzureModelRegistryService', () => {
it('getAccounts should return the list of accounts successfully', async function (): Promise<void> {
let testContext = createContext();
const accounts = testContext.accounts;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
testContext.apiWrapper.setup(x => x.getAllAccounts()).returns(() => Promise.resolve(accounts));
let actual = await service.getAccounts();
should.deepEqual(actual, testContext.accounts);
});
it('getSubscriptions should return the list of subscriptions successfully', async function (): Promise<void> {
let testContext = createContext();
const expected = testContext.subscriptions;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ subscriptions: expected, errors: [] }));
let actual = await service.getSubscriptions(testContext.accounts[0]);
should.deepEqual(actual, expected);
});
it('getGroups should return the list of groups successfully', async function (): Promise<void> {
let testContext = createContext();
const expected = testContext.groups;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ resourceGroups: expected, errors: [] }));
let actual = await service.getGroups(testContext.accounts[0], testContext.subscriptions[0]);
should.deepEqual(actual, expected);
});
it('getWorkspaces should return the list of workspaces successfully', async function (): Promise<void> {
let testContext = createContext();
const response: WorkspacesListByResourceGroupResponse = Object.assign(new Array<Workspace>(...testContext.workspaces), {
_response: undefined!
});
const expected = testContext.workspaces;
testContext.workspacesClient.setup(x => x.listByResourceGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(response));
testContext.workspacesClient.setup(x => x.listBySubscription()).returns(() => Promise.resolve(response));
testContext.client.setup(x => x.workspaces).returns(() => testContext.workspacesClient.object);
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
service.AzureMachineLearningClient = testContext.client.object;
let actual = await service.getWorkspaces(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0]);
should.deepEqual(actual, expected);
});
it('getModels should return the list of models successfully', async function (): Promise<void> {
let testContext = createContext();
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
const expected = testContext.models;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
service.AzureMachineLearningClient = testContext.client.object;
service.ModelClient = testContext.modelClient.object;
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
let actual = await service.getModels(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0]);
should.deepEqual(actual, expected);
});
it('downloadModel should download model artifact successfully', async function (): Promise<void> {
let testContext = createContext();
const asset: Asset =
{
id: '1',
name: 'asset',
artifacts: [
{
id: '/1/2/3/4/5/'
}
]
};
const assetResponse: AssetsQueryByIdResponse = Object.assign(asset, {
_response: undefined!
});
const artifactResponse: GetArtifactContentInformation2Response = Object.assign({
contentUri: 'downloadUrl'
}, {
_response: undefined!
});
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
testContext.config.setup(x => x.amlExperienceUrl).returns(() => 'test.url');
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('modelmanagement')), TypeMoq.It.isAny())).returns(() => Promise.resolve(assetResponse));
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('artifact')), TypeMoq.It.isAny())).returns(() => Promise.resolve(artifactResponse));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
service.AzureMachineLearningClient = testContext.client.object;
service.ModelClient = testContext.modelClient.object;
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
let actual = await service.downloadModel(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0], testContext.models[0]);
should.notEqual(actual, undefined);
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
});

View File

@@ -0,0 +1,410 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as utils from '../../common/utils';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { Config } from '../../configurations/config';
import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { QueryRunner } from '../../common/queryRunner';
import { RegisteredModel } from '../../modelManagement/interfaces';
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
import * as path from 'path';
import * as os from 'os';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as fs from 'fs';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
config: TypeMoq.IMock<Config>;
queryRunner: TypeMoq.IMock<QueryRunner>;
modelClient: TypeMoq.IMock<ModelPythonClient>;
}
function createContext(): TestContext {
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
};
}
describe('DeployedModelService', () => {
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
const testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
await should(service.getDeployedModels()).rejected();
});
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const expected: RegisteredModel[] = [
{
id: 1,
artifactName: 'name1',
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
}
];
const result = {
rowCount: 1,
columnInfo: [],
rows: [
[
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'name1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'title1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'desc1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1.1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
}
]
]
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
const actual = await service.getDeployedModels();
should.deepEqual(actual, expected);
});
it('loadModelParameters should load parameters using python client successfully', async function (): Promise<void> {
const testContext = createContext();
const expected = {
inputs: [
{
'name': 'p1',
'type': 'int'
},
{
'name': 'p2',
'type': 'varchar'
}
],
outputs: [
{
'name': 'o1',
'type': 'int'
},
]
};
testContext.modelClient.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(expected));
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
const actual = await service.loadModelParameters('');
should.deepEqual(actual, expected);
});
it('downloadModel should download model successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, 'test');
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const model: RegisteredModel =
{
id: 1,
artifactName: 'name1',
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
};
const result = {
rowCount: 1,
columnInfo: [],
rows: [
[
{
displayValue: await utils.readFileInHex(tempFilePath),
isNull: false,
invariantCultureDisplayValue: ''
}
]
]
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
const actual = await service.downloadModel(model);
should.notEqual(actual, undefined);
});
it('deployLocalModel should returns models successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const model: RegisteredModel =
{
id: 1,
artifactName: 'name1',
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
};
const row = [
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'name1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'title1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'desc1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1.1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
}
];
const result = {
rowCount: 1,
columnInfo: [],
rows: [row]
};
let updatedResult = {
rowCount: 1,
columnInfo: [],
rows: [row, row]
};
let deployed = false;
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => {
deployed = true;
return Promise.resolve();
});
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
});
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
await should(service.deployLocalModel('', model)).resolved();
});
it('getConfigureQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const dbName = 'curre[n]tDb';
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
const expected = `
IF NOT EXISTS (
SELECT [name]
FROM sys.databases
WHERE [name] = N'd[]b'
)
CREATE DATABASE [d[[]]b]
GO
USE [d[[]]b]
IF EXISTS
( SELECT [t.name], [s.name]
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE [t.name] = 'ta[b]le'
AND [s.name] = 'dbo'
)
BEGIN
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
ALTER TABLE [dbo].[ta[[b]]le] ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='created')
BEGIN
ALTER TABLE [dbo].[ta[[b]]le] ADD [created] [datetime] NULL
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
END
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='description')
ALTER TABLE [dbo].[ta[[b]]le] ADD [description] [varchar](256) NULL
END
Else
BEGIN
CREATE TABLE [dbo].[ta[[b]]le](
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
[artifact_name] [varchar](256) NOT NULL,
[group_path] [varchar](256) NOT NULL,
[artifact_content] [varbinary](max) NOT NULL,
[artifact_initial_size] [bigint] NULL,
[name] [varchar](256) NULL,
[version] [varchar](256) NULL,
[created] [datetime] NULL,
[description] [varchar](256) NULL,
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
(
[artifact_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
END
`;
const actual = service.getConfigureQuery(dbName);
should.equal(actual.indexOf(expected) > 0, true);
});
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
const expected = `
SELECT artifact_id, artifact_name, name, description, version, created
FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
const actual = service.getDeployedModelsQuery();
should.deepEqual(expected, actual);
});
it('getUpdateModelQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const dbName = 'curre[n]tDb';
const model: RegisteredModel =
{
id: 1,
artifactName: 'name1',
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
const expected = `
UPDATE [dbo].[ta[[b]]le]
SET
name = 'title1',
version = '1.1',
description = 'desc1'
WHERE artifact_id = 1`;
const actual = service.getUpdateModelQuery(dbName, model);
should.equal(actual.indexOf(expected) > 0, true);
//should.deepEqual(actual, expected);
});
it('getModelContentQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const model: RegisteredModel =
{
id: 1,
artifactName: 'name1',
title: 'title1',
description: 'desc1',
created: '2018-01-01',
version: '1.1'
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object);
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
const expected = `
SELECT artifact_content
FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE artifact_id = 1;
`;
const actual = service.getModelContentQuery(model);
should.deepEqual(actual, expected);
});
});

View File

@@ -0,0 +1,121 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { Config } from '../../configurations/config';
import * as utils from '../utils';
import { ProcessService } from '../../common/processService';
import { PackageManager } from '../../packageManagement/packageManager';
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
config: TypeMoq.IMock<Config>;
outputChannel: vscode.OutputChannel;
op: azdata.BackgroundOperation;
processService: TypeMoq.IMock<ProcessService>;
packageManager: TypeMoq.IMock<PackageManager>;
}
function createContext(): TestContext {
const context = utils.createContext();
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
outputChannel: context.outputChannel,
op: context.op,
processService: TypeMoq.Mock.ofType(ProcessService),
packageManager: TypeMoq.Mock.ofType(PackageManager)
};
}
describe('ModelPythonClient', () => {
it('deployModel should deploy the model successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
const modelPath = 'C:\\test';
let service = new ModelPythonClient(
testContext.outputChannel,
testContext.apiWrapper.object,
testContext.processService.object,
testContext.config.object,
testContext.packageManager.object);
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
await service.deployModel(connection, modelPath);
});
it('loadModelParameters should load model parameters successfully', async function (): Promise<void> {
const testContext = createContext();
const modelPath = 'C:\\test';
const expected = {
inputs: [
{
'name': 'p1',
'type': 'int'
},
{
'name': 'p2',
'type': 'varchar'
}
],
outputs: [
{
'name': 'o1',
'type': 'int'
},
]
};
const parametersJson = `
{
"inputs": [
{
"name": "p1",
"type": "int"
},
{
"name": "p2",
"type": "varchar"
}
],
"outputs": [
{
"name": "o1",
"type": "int"
}
]
}
`;
let service = new ModelPythonClient(
testContext.outputChannel,
testContext.apiWrapper.object,
testContext.processService.object,
testContext.config.object,
testContext.packageManager.object);
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
const actual = await service.loadModelParameters(modelPath);
should.deepEqual(actual, expected);
});
});