mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-11 10:38:31 -05:00
Machine Learning Services - Model detection in predict wizard (#9609)
* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
@@ -0,0 +1,232 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { HttpClient } from '../../common/httpClient';
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
|
||||
import * as utils from '../utils';
|
||||
import { Workspace, WorkspacesListByResourceGroupResponse } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { WorkspaceModel, AssetsQueryByIdResponse, Asset, GetArtifactContentInformation2Response } from '../../modelManagement/interfaces';
|
||||
import { AzureMachineLearningWorkspaces, Workspaces } from '@azure/arm-machinelearningservices';
|
||||
import { WorkspaceModels } from '../../modelManagement/workspacesModels';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
httpClient: TypeMoq.IMock<HttpClient>;
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
accounts: azdata.Account[];
|
||||
subscriptions: azureResource.AzureResourceSubscription[];
|
||||
groups: azureResource.AzureResourceResourceGroup[];
|
||||
workspaces: Workspace[];
|
||||
models: WorkspaceModel[];
|
||||
client: TypeMoq.IMock<AzureMachineLearningWorkspaces>;
|
||||
workspacesClient: TypeMoq.IMock<Workspaces>;
|
||||
modelClient: TypeMoq.IMock<WorkspaceModels>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
const context = utils.createContext();
|
||||
const workspaces = TypeMoq.Mock.ofType(Workspaces);
|
||||
const credentials = {
|
||||
signRequest: () => {
|
||||
return Promise.resolve(undefined!!);
|
||||
}
|
||||
};
|
||||
const client = TypeMoq.Mock.ofInstance(new AzureMachineLearningWorkspaces(credentials, 'subscription'));
|
||||
client.setup(x => x.apiVersion).returns(() => '20180101');
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
httpClient: TypeMoq.Mock.ofType(HttpClient),
|
||||
outputChannel: context.outputChannel,
|
||||
op: context.op,
|
||||
accounts: [
|
||||
{
|
||||
key: {
|
||||
providerId: '',
|
||||
accountId: 'a1'
|
||||
},
|
||||
displayInfo: {
|
||||
contextualDisplayName: '',
|
||||
accountType: '',
|
||||
displayName: 'a1',
|
||||
userId: 'a1'
|
||||
},
|
||||
properties:
|
||||
{
|
||||
tenants: [
|
||||
{
|
||||
id: '1',
|
||||
}
|
||||
]
|
||||
}
|
||||
,
|
||||
isStale: true
|
||||
}
|
||||
],
|
||||
subscriptions: [
|
||||
{
|
||||
name: 's1',
|
||||
id: 's1'
|
||||
}
|
||||
],
|
||||
groups: [
|
||||
{
|
||||
name: 'g1',
|
||||
id: 'g1'
|
||||
}
|
||||
],
|
||||
workspaces: [{
|
||||
name: 'w1',
|
||||
id: 'w1'
|
||||
}
|
||||
],
|
||||
models: [
|
||||
{
|
||||
name: 'm1',
|
||||
id: 'm1',
|
||||
url: 'aml://asset/test.test'
|
||||
}
|
||||
],
|
||||
client: client,
|
||||
workspacesClient: workspaces,
|
||||
modelClient: TypeMoq.Mock.ofInstance(new WorkspaceModels(client.object))
|
||||
};
|
||||
}
|
||||
|
||||
describe('AzureModelRegistryService', () => {
|
||||
it('getAccounts should return the list of accounts successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const accounts = testContext.accounts;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.getAllAccounts()).returns(() => Promise.resolve(accounts));
|
||||
let actual = await service.getAccounts();
|
||||
should.deepEqual(actual, testContext.accounts);
|
||||
});
|
||||
|
||||
it('getSubscriptions should return the list of subscriptions successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const expected = testContext.subscriptions;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ subscriptions: expected, errors: [] }));
|
||||
let actual = await service.getSubscriptions(testContext.accounts[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getGroups should return the list of groups successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const expected = testContext.groups;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ resourceGroups: expected, errors: [] }));
|
||||
let actual = await service.getGroups(testContext.accounts[0], testContext.subscriptions[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getWorkspaces should return the list of workspaces successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const response: WorkspacesListByResourceGroupResponse = Object.assign(new Array<Workspace>(...testContext.workspaces), {
|
||||
_response: undefined!
|
||||
});
|
||||
const expected = testContext.workspaces;
|
||||
testContext.workspacesClient.setup(x => x.listByResourceGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(response));
|
||||
testContext.workspacesClient.setup(x => x.listBySubscription()).returns(() => Promise.resolve(response));
|
||||
testContext.client.setup(x => x.workspaces).returns(() => testContext.workspacesClient.object);
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
|
||||
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
let actual = await service.getWorkspaces(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getModels should return the list of models successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
|
||||
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
|
||||
const expected = testContext.models;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
service.ModelClient = testContext.modelClient.object;
|
||||
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
|
||||
let actual = await service.getModels(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('downloadModel should download model artifact successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const asset: Asset =
|
||||
{
|
||||
id: '1',
|
||||
name: 'asset',
|
||||
artifacts: [
|
||||
{
|
||||
id: '/1/2/3/4/5/'
|
||||
}
|
||||
]
|
||||
};
|
||||
const assetResponse: AssetsQueryByIdResponse = Object.assign(asset, {
|
||||
_response: undefined!
|
||||
});
|
||||
const artifactResponse: GetArtifactContentInformation2Response = Object.assign({
|
||||
contentUri: 'downloadUrl'
|
||||
}, {
|
||||
_response: undefined!
|
||||
});
|
||||
|
||||
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
|
||||
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
|
||||
testContext.config.setup(x => x.amlExperienceUrl).returns(() => 'test.url');
|
||||
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
|
||||
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('modelmanagement')), TypeMoq.It.isAny())).returns(() => Promise.resolve(assetResponse));
|
||||
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
|
||||
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('artifact')), TypeMoq.It.isAny())).returns(() => Promise.resolve(artifactResponse));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
service.ModelClient = testContext.modelClient.object;
|
||||
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
|
||||
let actual = await service.downloadModel(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0], testContext.models[0]);
|
||||
should.notEqual(actual, undefined);
|
||||
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,410 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as utils from '../../common/utils';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as fs from 'fs';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
modelClient: TypeMoq.IMock<ModelPythonClient>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
|
||||
};
|
||||
}
|
||||
|
||||
describe('DeployedModelService', () => {
|
||||
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
await should(service.getDeployedModels()).rejected();
|
||||
});
|
||||
|
||||
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const expected: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
}
|
||||
];
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [
|
||||
[
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'name1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'title1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'desc1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1.1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]
|
||||
]
|
||||
};
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
const actual = await service.getDeployedModels();
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('loadModelParameters should load parameters using python client successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
},
|
||||
]
|
||||
};
|
||||
testContext.modelClient.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(expected));
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
const actual = await service.loadModelParameters('');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('downloadModel should download model successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [
|
||||
[
|
||||
{
|
||||
displayValue: await utils.readFileInHex(tempFilePath),
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]
|
||||
]
|
||||
};
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const actual = await service.downloadModel(model);
|
||||
should.notEqual(actual, undefined);
|
||||
});
|
||||
|
||||
it('deployLocalModel should returns models successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
const row = [
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'name1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'title1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'desc1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1.1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
];
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [row]
|
||||
};
|
||||
let updatedResult = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [row, row]
|
||||
};
|
||||
let deployed = false;
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => {
|
||||
deployed = true;
|
||||
return Promise.resolve();
|
||||
});
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
||||
});
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
await should(service.deployLocalModel('', model)).resolved();
|
||||
});
|
||||
|
||||
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const dbName = 'curre[n]tDb';
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
IF NOT EXISTS (
|
||||
SELECT [name]
|
||||
FROM sys.databases
|
||||
WHERE [name] = N'd[]b'
|
||||
)
|
||||
CREATE DATABASE [d[[]]b]
|
||||
GO
|
||||
USE [d[[]]b]
|
||||
IF EXISTS
|
||||
( SELECT [t.name], [s.name]
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE [t.name] = 'ta[b]le'
|
||||
AND [s.name] = 'dbo'
|
||||
)
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='created')
|
||||
BEGIN
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [created] [datetime] NULL
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
END
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='description')
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [description] [varchar](256) NULL
|
||||
END
|
||||
Else
|
||||
BEGIN
|
||||
CREATE TABLE [dbo].[ta[[b]]le](
|
||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[artifact_name] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NOT NULL,
|
||||
[artifact_content] [varbinary](max) NOT NULL,
|
||||
[artifact_initial_size] [bigint] NULL,
|
||||
[name] [varchar](256) NULL,
|
||||
[version] [varchar](256) NULL,
|
||||
[created] [datetime] NULL,
|
||||
[description] [varchar](256) NULL,
|
||||
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[artifact_id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||
END
|
||||
`;
|
||||
const actual = service.getConfigureQuery(dbName);
|
||||
should.equal(actual.indexOf(expected) > 0, true);
|
||||
});
|
||||
|
||||
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||
Order by artifact_id
|
||||
`;
|
||||
const actual = service.getDeployedModelsQuery();
|
||||
should.deepEqual(expected, actual);
|
||||
});
|
||||
|
||||
it('getUpdateModelQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const dbName = 'curre[n]tDb';
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
UPDATE [dbo].[ta[[b]]le]
|
||||
SET
|
||||
name = 'title1',
|
||||
version = '1.1',
|
||||
description = 'desc1'
|
||||
WHERE artifact_id = 1`;
|
||||
const actual = service.getUpdateModelQuery(dbName, model);
|
||||
should.equal(actual.indexOf(expected) > 0, true);
|
||||
//should.deepEqual(actual, expected);
|
||||
|
||||
});
|
||||
|
||||
it('getModelContentQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
SELECT artifact_content
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
WHERE artifact_id = 1;
|
||||
`;
|
||||
const actual = service.getModelContentQuery(model);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,121 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
|
||||
import * as utils from '../utils';
|
||||
import { ProcessService } from '../../common/processService';
|
||||
import { PackageManager } from '../../packageManagement/packageManager';
|
||||
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
processService: TypeMoq.IMock<ProcessService>;
|
||||
packageManager: TypeMoq.IMock<PackageManager>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
const context = utils.createContext();
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
outputChannel: context.outputChannel,
|
||||
op: context.op,
|
||||
processService: TypeMoq.Mock.ofType(ProcessService),
|
||||
packageManager: TypeMoq.Mock.ofType(PackageManager)
|
||||
};
|
||||
}
|
||||
|
||||
describe('ModelPythonClient', () => {
|
||||
it('deployModel should deploy the model successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
const modelPath = 'C:\\test';
|
||||
let service = new ModelPythonClient(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.packageManager.object);
|
||||
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
||||
|
||||
await service.deployModel(connection, modelPath);
|
||||
});
|
||||
|
||||
it('loadModelParameters should load model parameters successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const modelPath = 'C:\\test';
|
||||
const expected = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
},
|
||||
]
|
||||
};
|
||||
const parametersJson = `
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"name": "p1",
|
||||
"type": "int"
|
||||
},
|
||||
{
|
||||
"name": "p2",
|
||||
"type": "varchar"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "o1",
|
||||
"type": "int"
|
||||
}
|
||||
]
|
||||
}
|
||||
`;
|
||||
let service = new ModelPythonClient(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.packageManager.object);
|
||||
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
|
||||
const actual = await service.loadModelParameters(modelPath);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
});
|
||||
@@ -354,7 +354,7 @@ describe('SQL Python Package Manager', () => {
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocationTitle();
|
||||
|
||||
should.deepEqual(actual, constants.packageManagerNoConnection);
|
||||
should.deepEqual(actual, constants.noConnectionError);
|
||||
});
|
||||
|
||||
it('getLocationTitle Should return connection title string for valid connection', async function (): Promise<void> {
|
||||
|
||||
@@ -279,7 +279,7 @@ describe('SQL R Package Manager', () => {
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocationTitle();
|
||||
|
||||
should.deepEqual(actual, constants.packageManagerNoConnection);
|
||||
should.deepEqual(actual, constants.noConnectionError);
|
||||
});
|
||||
|
||||
it('getLocationTitle Should return connection title string for valid connection', async function (): Promise<void> {
|
||||
|
||||
@@ -11,6 +11,7 @@ import { QueryRunner } from '../../common/queryRunner';
|
||||
import { ProcessService } from '../../common/processService';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { HttpClient } from '../../common/httpClient';
|
||||
import * as utils from '../utils';
|
||||
import { PackageManagementService } from '../../packageManagement/packageManagementService';
|
||||
|
||||
export interface TestContext {
|
||||
@@ -27,31 +28,18 @@ export interface TestContext {
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
let opStatus: azdata.TaskStatus;
|
||||
const context = utils.createContext();
|
||||
|
||||
return {
|
||||
outputChannel: {
|
||||
name: '',
|
||||
append: () => { },
|
||||
appendLine: () => { },
|
||||
clear: () => { },
|
||||
show: () => { },
|
||||
hide: () => { },
|
||||
dispose: () => { }
|
||||
},
|
||||
|
||||
outputChannel: context.outputChannel,
|
||||
processService: TypeMoq.Mock.ofType(ProcessService),
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
httpClient: TypeMoq.Mock.ofType(HttpClient),
|
||||
op: {
|
||||
updateStatus: (status: azdata.TaskStatus) => {
|
||||
opStatus = status;
|
||||
},
|
||||
id: '',
|
||||
onCanceled: new vscode.EventEmitter<void>().event,
|
||||
},
|
||||
getOpStatus: () => { return opStatus; },
|
||||
op: context.op,
|
||||
getOpStatus: context.getOpStatus,
|
||||
serverConfigManager: TypeMoq.Mock.ofType(PackageManagementService)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as fs from 'fs';
|
||||
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
|
||||
};
|
||||
}
|
||||
|
||||
describe('PredictService', () => {
|
||||
|
||||
it('getDatabaseList should return databases successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.listDatabases(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(expected); });
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
const actual = await service.getDatabaseList();
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getTableList should return tables successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [[
|
||||
{
|
||||
displayValue: 'tb1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'dbo',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
], [
|
||||
{
|
||||
displayValue: 'tb2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'dbo',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]]
|
||||
};
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
const actual = await service.getTableList('db1');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getTableColumnsList should return table columns successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const table: DatabaseTable =
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
};
|
||||
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [[
|
||||
{
|
||||
displayValue: 'c1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'int',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
], [
|
||||
{
|
||||
displayValue: 'c2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'varchar',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]]
|
||||
};
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
const actual = await service.getTableColumnsList(table);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('generatePredictScript should generate the script successfully using model', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const predictParams: PredictParameters = {
|
||||
inputColumns: [
|
||||
{
|
||||
paramName: 'p1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
dataType: 'varchar',
|
||||
columnName: ''
|
||||
}
|
||||
],
|
||||
outputColumns: [
|
||||
{
|
||||
paramName: 'o1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
],
|
||||
databaseName: '',
|
||||
tableName: '',
|
||||
schema: ''
|
||||
};
|
||||
const model: RegisteredModel =
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'name1',
|
||||
title: 'title1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1'
|
||||
};
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
fileName: '',
|
||||
isUntitled: true,
|
||||
languageId: 'sql',
|
||||
version: 1,
|
||||
isDirty: true,
|
||||
isClosed: false,
|
||||
save: undefined!,
|
||||
eol: undefined!,
|
||||
lineCount: 1,
|
||||
lineAt: undefined!,
|
||||
offsetAt: undefined!,
|
||||
positionAt: undefined!,
|
||||
getText: undefined!,
|
||||
getWordRangeAtPosition: undefined!,
|
||||
validateRange: undefined!,
|
||||
validatePosition: undefined!
|
||||
};
|
||||
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
|
||||
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
|
||||
|
||||
const actual = await service.generatePredictScript(predictParams, model, undefined);
|
||||
should.notEqual(actual, undefined);
|
||||
should.equal(actual.indexOf('FROM PREDICT(MODEL = @model') > 0, true);
|
||||
});
|
||||
|
||||
it('generatePredictScript should generate the script successfully using file', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const predictParams: PredictParameters = {
|
||||
inputColumns: [
|
||||
{
|
||||
paramName: 'p1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
dataType: 'varchar',
|
||||
columnName: ''
|
||||
}
|
||||
],
|
||||
outputColumns: [
|
||||
{
|
||||
paramName: 'o1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
],
|
||||
databaseName: '',
|
||||
tableName: '',
|
||||
schema: ''
|
||||
};
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.config.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
fileName: '',
|
||||
isUntitled: true,
|
||||
languageId: 'sql',
|
||||
version: 1,
|
||||
isDirty: true,
|
||||
isClosed: false,
|
||||
save: undefined!,
|
||||
eol: undefined!,
|
||||
lineCount: 1,
|
||||
lineAt: undefined!,
|
||||
offsetAt: undefined!,
|
||||
positionAt: undefined!,
|
||||
getText: undefined!,
|
||||
getWordRangeAtPosition: undefined!,
|
||||
validateRange: undefined!,
|
||||
validatePosition: undefined!
|
||||
};
|
||||
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
|
||||
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
|
||||
|
||||
const actual = await service.generatePredictScript(predictParams, undefined, tempFilePath);
|
||||
should.notEqual(actual, undefined);
|
||||
should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
|
||||
});
|
||||
});
|
||||
38
extensions/machine-learning-services/src/test/utils.ts
Normal file
38
extensions/machine-learning-services/src/test/utils.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
export interface TestContext {
|
||||
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
getOpStatus: () => azdata.TaskStatus;
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
let opStatus: azdata.TaskStatus;
|
||||
|
||||
return {
|
||||
outputChannel: {
|
||||
name: '',
|
||||
append: () => { },
|
||||
appendLine: () => { },
|
||||
clear: () => { },
|
||||
show: () => { },
|
||||
hide: () => { },
|
||||
dispose: () => { }
|
||||
},
|
||||
op: {
|
||||
updateStatus: (status: azdata.TaskStatus) => {
|
||||
opStatus = status;
|
||||
},
|
||||
id: '',
|
||||
onCanceled: new vscode.EventEmitter<void>().event,
|
||||
},
|
||||
getOpStatus: () => { return opStatus; }
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import {
|
||||
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
|
||||
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName
|
||||
}
|
||||
from '../../../views/models/modelViewBase';
|
||||
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { PredictWizard } from '../../../views/models/prediction/predictWizard';
|
||||
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||
|
||||
describe('Predict Wizard', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||
await view.open();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
});
|
||||
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
accountId: '1',
|
||||
providerId: ''
|
||||
},
|
||||
displayInfo: {
|
||||
displayName: 'account',
|
||||
userId: '',
|
||||
accountType: '',
|
||||
contextualDisplayName: ''
|
||||
},
|
||||
isStale: false,
|
||||
properties: []
|
||||
}
|
||||
];
|
||||
let subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||
{
|
||||
name: 'subscription',
|
||||
id: '2'
|
||||
}
|
||||
];
|
||||
let groups: azureResource.AzureResourceResourceGroup[] = [
|
||||
{
|
||||
name: 'group',
|
||||
id: '3'
|
||||
}
|
||||
];
|
||||
let workspaces: Workspace[] = [
|
||||
{
|
||||
name: 'workspace',
|
||||
id: '4'
|
||||
}
|
||||
];
|
||||
let models: WorkspaceModel[] = [
|
||||
{
|
||||
id: '5',
|
||||
name: 'model'
|
||||
}
|
||||
];
|
||||
let localModels: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
artifactName: 'model',
|
||||
title: 'model'
|
||||
}
|
||||
];
|
||||
const dbNames: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const tableNames: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
const columnNames: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const modelParameters: ModelParameters = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
view.on(ListModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
|
||||
});
|
||||
view.on(ListAccountsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
||||
});
|
||||
view.on(ListSubscriptionsEventName, () => {
|
||||
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
|
||||
});
|
||||
view.on(ListGroupsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
|
||||
});
|
||||
view.on(ListWorkspacesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
|
||||
});
|
||||
view.on(ListAzureModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames });
|
||||
});
|
||||
view.on(ListTableNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames });
|
||||
});
|
||||
view.on(ListColumnNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames });
|
||||
});
|
||||
view.on(LoadModelParametersEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters });
|
||||
});
|
||||
view.on(DownloadAzureModelEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' });
|
||||
});
|
||||
view.on(DownloadRegisteredModelEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
|
||||
});
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data, undefined);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
|
||||
should.notEqual(await view.getModelFileName(), undefined);
|
||||
await view.columnsSelectionPage?.onEnter();
|
||||
|
||||
should.notEqual(view.columnsSelectionPage?.data, undefined);
|
||||
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name);
|
||||
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length);
|
||||
});
|
||||
});
|
||||
@@ -20,7 +20,7 @@ describe('Register Model Wizard', () => {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
@@ -30,7 +30,7 @@ describe('Register Model Wizard', () => {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
@@ -98,5 +98,7 @@ describe('Register Model Wizard', () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data ,undefined);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,14 +7,12 @@ import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as mssql from '../../../../../mssql/src/mssql';
|
||||
import { createViewContext } from '../utils';
|
||||
import { ModelViewBase } from '../../../views/models/modelViewBase';
|
||||
|
||||
export interface TestContext {
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
view: azdata.ModelView;
|
||||
languageExtensionService: mssql.ILanguageExtensionService;
|
||||
onClick: vscode.EventEmitter<any>;
|
||||
}
|
||||
|
||||
@@ -34,16 +32,10 @@ export class ParentDialog extends ModelViewBase {
|
||||
export function createContext(): TestContext {
|
||||
|
||||
let viewTestContext = createViewContext();
|
||||
let languageExtensionService: mssql.ILanguageExtensionService = {
|
||||
listLanguages: () => { return Promise.resolve([]); },
|
||||
deleteLanguage: () => { return Promise.resolve(); },
|
||||
updateLanguage: () => { return Promise.resolve(); }
|
||||
};
|
||||
|
||||
return {
|
||||
apiWrapper: viewTestContext.apiWrapper,
|
||||
view: viewTestContext.view,
|
||||
languageExtensionService: languageExtensionService,
|
||||
onClick: viewTestContext.onClick
|
||||
};
|
||||
}
|
||||
|
||||
@@ -62,6 +62,9 @@ export function createViewContext(): ViewTestContext {
|
||||
onTextChanged: undefined!,
|
||||
onEnterKeyPressed: undefined!,
|
||||
value: ''
|
||||
});
|
||||
let image: () => azdata.ImageComponent = () => Object.assign({}, componentBase, {
|
||||
|
||||
});
|
||||
let dropdown: () => azdata.DropDownComponent = () => Object.assign({}, componentBase, {
|
||||
onValueChanged: onClick.event,
|
||||
@@ -124,6 +127,14 @@ export function createViewContext(): ViewTestContext {
|
||||
withProperties: () => inputBoxBuilder,
|
||||
withValidation: () => inputBoxBuilder
|
||||
};
|
||||
let imageBuilder: azdata.ComponentBuilder<azdata.ImageComponent> = {
|
||||
component: () => {
|
||||
let r = image();
|
||||
return r;
|
||||
},
|
||||
withProperties: () => imageBuilder,
|
||||
withValidation: () => imageBuilder
|
||||
};
|
||||
let dropdownBuilder: azdata.ComponentBuilder<azdata.DropDownComponent> = {
|
||||
component: () => {
|
||||
let r = dropdown();
|
||||
@@ -156,7 +167,7 @@ export function createViewContext(): ViewTestContext {
|
||||
editor: undefined!,
|
||||
diffeditor: undefined!,
|
||||
text: () => inputBoxBuilder,
|
||||
image: undefined!,
|
||||
image: () => imageBuilder,
|
||||
button: () => buttonBuilder,
|
||||
dropDown: () => dropdownBuilder,
|
||||
tree: undefined!,
|
||||
@@ -181,7 +192,7 @@ export function createViewContext(): ViewTestContext {
|
||||
try {
|
||||
await handler(view);
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
onValidityChanged: undefined!,
|
||||
@@ -242,7 +253,13 @@ export function createViewContext(): ViewTestContext {
|
||||
enabled: true,
|
||||
description: '',
|
||||
onValidityChanged: onClick.event,
|
||||
registerContent: () => { },
|
||||
registerContent: async (handler) => {
|
||||
try {
|
||||
await handler(view);
|
||||
} catch (err) {
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
modelView: undefined!,
|
||||
valid: true
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user