mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-15 10:58:31 -05:00
Machine Learning - Supporting multiple model import (#9869)
* Machine Learning Extension - Changed the deploy wizard to deploy multiple files
This commit is contained in:
@@ -248,9 +248,10 @@ describe('DeployedModelService', () => {
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object);
|
||||
testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => {
|
||||
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => {
|
||||
deployed = true;
|
||||
return Promise.resolve();
|
||||
return Promise.resolve(result);
|
||||
});
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
||||
@@ -259,7 +260,15 @@ describe('DeployedModelService', () => {
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
await should(service.deployLocalModel('', model)).resolved();
|
||||
let tempFilePath: string = '';
|
||||
try {
|
||||
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
await should(service.deployLocalModel(tempFilePath, model)).resolved();
|
||||
}
|
||||
finally {
|
||||
await utils.deleteFile(tempFilePath);
|
||||
}
|
||||
});
|
||||
|
||||
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
||||
@@ -306,7 +315,7 @@ describe('DeployedModelService', () => {
|
||||
CREATE TABLE [dbo].[ta[[b]]le](
|
||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[artifact_name] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NOT NULL,
|
||||
[group_path] [varchar](256) NULL,
|
||||
[artifact_content] [varbinary](max) NOT NULL,
|
||||
[artifact_initial_size] [bigint] NULL,
|
||||
[name] [varchar](256) NULL,
|
||||
@@ -345,7 +354,7 @@ describe('DeployedModelService', () => {
|
||||
should.deepEqual(expected, actual);
|
||||
});
|
||||
|
||||
it('getUpdateModelQuery should escape db name', async function (): Promise<void> {
|
||||
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const dbName = 'curre[n]tDb';
|
||||
const model: RegisteredModel =
|
||||
@@ -367,16 +376,17 @@ describe('DeployedModelService', () => {
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const expected = `
|
||||
UPDATE [dbo].[ta[[b]]le]
|
||||
SET
|
||||
name = 'title1',
|
||||
version = '1.1',
|
||||
description = 'desc1'
|
||||
WHERE artifact_id = 1`;
|
||||
const actual = service.getUpdateModelQuery(dbName, model);
|
||||
Insert into [dbo].[ta[[b]]le]
|
||||
(artifact_name, group_path, artifact_content, name, version, description)
|
||||
values (
|
||||
'name1',
|
||||
'ADS',
|
||||
,
|
||||
'title1',
|
||||
'1.1',
|
||||
'desc1')`;
|
||||
const actual = service.getInsertModelQuery(dbName, model);
|
||||
should.equal(actual.indexOf(expected) > 0, true);
|
||||
//should.deepEqual(actual, expected);
|
||||
|
||||
});
|
||||
|
||||
it('getModelContentQuery should escape db name', async function (): Promise<void> {
|
||||
|
||||
@@ -28,7 +28,7 @@ describe('Azure Models Component', () => {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
|
||||
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent);
|
||||
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent, false);
|
||||
view.registerComponent(testContext.view.modelBuilder);
|
||||
|
||||
let accounts: azdata.Account[] = [
|
||||
@@ -88,12 +88,15 @@ describe('Azure Models Component', () => {
|
||||
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
await view.refresh();
|
||||
testContext.onClick.fire();
|
||||
testContext.onClick.fire(true);
|
||||
should.notEqual(view.data, undefined);
|
||||
should.deepEqual(view.data?.account, accounts[0]);
|
||||
should.deepEqual(view.data?.subscription, subscriptions[0]);
|
||||
should.deepEqual(view.data?.group, groups[0]);
|
||||
should.deepEqual(view.data?.workspace, workspaces[0]);
|
||||
should.deepEqual(view.data?.model, models[0]);
|
||||
should.equal(view.data?.length, 1);
|
||||
if (view.data) {
|
||||
should.deepEqual(view.data[0].account, accounts[0]);
|
||||
should.deepEqual(view.data[0].subscription, subscriptions[0]);
|
||||
should.deepEqual(view.data[0].group, groups[0]);
|
||||
should.deepEqual(view.data[0].workspace, workspaces[0]);
|
||||
should.deepEqual(view.data[0].model, models[0]);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,7 +9,7 @@ import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import {
|
||||
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
|
||||
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName
|
||||
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType
|
||||
}
|
||||
from '../../../views/models/modelViewBase';
|
||||
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
@@ -164,9 +164,25 @@ describe('Predict Wizard', () => {
|
||||
view.on(DownloadRegisteredModelEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
|
||||
});
|
||||
if (view.modelBrowsePage) {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
|
||||
}
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data, undefined);
|
||||
|
||||
if (view.modelBrowsePage) {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
|
||||
}
|
||||
await view.refresh();
|
||||
testContext.onClick.fire();
|
||||
|
||||
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined);
|
||||
if (view.modelBrowsePage?.registeredModelsComponent?.data) {
|
||||
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1);
|
||||
}
|
||||
|
||||
|
||||
should.notEqual(await view.getModelFileName(), undefined);
|
||||
await view.columnsSelectionPage?.onEnter();
|
||||
|
||||
@@ -7,7 +7,7 @@ import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName } from '../../../views/models/modelViewBase';
|
||||
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType } from '../../../views/models/modelViewBase';
|
||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
@@ -97,6 +97,10 @@ describe('Register Model Wizard', () => {
|
||||
view.on(ListAzureModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
|
||||
if (view.modelBrowsePage) {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
|
||||
}
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data ,undefined);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
|
||||
@@ -32,8 +32,13 @@ export function createViewContext(): ViewTestContext {
|
||||
onDidClick: onClick.event
|
||||
});
|
||||
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
|
||||
checked: true,
|
||||
onDidClick: onClick.event
|
||||
});
|
||||
let checkbox: azdata.CheckBoxComponent = Object.assign({}, componentBase, {
|
||||
checked: true,
|
||||
onChanged: onClick.event
|
||||
});
|
||||
let container = {
|
||||
clearItems: () => { },
|
||||
addItems: () => { },
|
||||
@@ -58,6 +63,11 @@ export function createViewContext(): ViewTestContext {
|
||||
withProperties: () => radioButtonBuilder,
|
||||
withValidation: () => radioButtonBuilder
|
||||
};
|
||||
let checkBoxBuilder: azdata.ComponentBuilder<azdata.CheckBoxComponent> = {
|
||||
component: () => checkbox,
|
||||
withProperties: () => checkBoxBuilder,
|
||||
withValidation: () => checkBoxBuilder
|
||||
};
|
||||
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
|
||||
onTextChanged: undefined!,
|
||||
onEnterKeyPressed: undefined!,
|
||||
@@ -85,6 +95,12 @@ export function createViewContext(): ViewTestContext {
|
||||
component: undefined!
|
||||
});
|
||||
|
||||
let card: () => azdata.CardComponent = () => Object.assign({}, componentBase, {
|
||||
label: '',
|
||||
onDidActionClick: new vscode.EventEmitter<azdata.ActionDescriptor>().event,
|
||||
onCardSelectedChanged: onClick.event
|
||||
});
|
||||
|
||||
let declarativeTableBuilder: azdata.ComponentBuilder<azdata.DeclarativeTableComponent> = {
|
||||
component: () => declarativeTable(),
|
||||
withProperties: () => declarativeTableBuilder,
|
||||
@@ -127,6 +143,15 @@ export function createViewContext(): ViewTestContext {
|
||||
withProperties: () => inputBoxBuilder,
|
||||
withValidation: () => inputBoxBuilder
|
||||
};
|
||||
let cardBuilder: azdata.ComponentBuilder<azdata.CardComponent> = {
|
||||
component: () => {
|
||||
let r = card();
|
||||
return r;
|
||||
},
|
||||
withProperties: () => cardBuilder,
|
||||
withValidation: () => cardBuilder
|
||||
};
|
||||
|
||||
let imageBuilder: azdata.ComponentBuilder<azdata.ImageComponent> = {
|
||||
component: () => {
|
||||
let r = image();
|
||||
@@ -159,9 +184,9 @@ export function createViewContext(): ViewTestContext {
|
||||
flexContainer: () => flexBuilder,
|
||||
splitViewContainer: undefined!,
|
||||
dom: undefined!,
|
||||
card: undefined!,
|
||||
card: () => cardBuilder,
|
||||
inputBox: () => inputBoxBuilder,
|
||||
checkBox: undefined!,
|
||||
checkBox: () => checkBoxBuilder!,
|
||||
radioButton: () => radioButtonBuilder,
|
||||
webView: undefined!,
|
||||
editor: undefined!,
|
||||
|
||||
Reference in New Issue
Block a user