ML - Added model management dialog to edit and delete models (#10125)

* ML - Added model management dialog to edit and delete models
This commit is contained in:
Leila Lali
2020-04-24 08:33:35 -07:00
committed by GitHub
parent 7633c810aa
commit 1e8a9c47cb
34 changed files with 1184 additions and 462 deletions

View File

@@ -20,6 +20,7 @@ export const extensionOutputChannel = 'SQL Machine Learning';
export const notebookExtensionName = 'Microsoft.notebook'; export const notebookExtensionName = 'Microsoft.notebook';
export const azureSubscriptionsCommand = 'azure.accounts.getSubscriptions'; export const azureSubscriptionsCommand = 'azure.accounts.getSubscriptions';
export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups'; export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups';
export const signInToAzureCommand = 'azure.resource.signin';
// Tasks, commands // Tasks, commands
// //
@@ -57,6 +58,10 @@ export function confirmInstallPythonPackages(packages: string): string {
return localize('mls.installDependencies.confirmInstallPythonPackages' return localize('mls.installDependencies.confirmInstallPythonPackages'
, "The following Python packages are required to install: {0}. Are you sure you want to install?", packages); , "The following Python packages are required to install: {0}. Are you sure you want to install?", packages);
} }
export function confirmDeleteModel(modelName: string): string {
return localize('models.confirmDeleteModel'
, "Are you sure you want to delete model '{0}?", modelName);
}
export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ..."); export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ...");
export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed."); export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); } export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
@@ -114,16 +119,20 @@ export const extLangSelectedPath = localize('extLang.selectedPath', "Selected Pa
export const extLangInstallFailedError = localize('extLang.installFailedError', "Failed to install language"); export const extLangInstallFailedError = localize('extLang.installFailedError', "Failed to install language");
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language"); export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
export const modelArtifactName = localize('models.artifactName', "Artifact Name"); export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model");
export const databaseName = localize('databaseName', "Database name"); export const databaseName = localize('databaseName', "Database name");
export const tableName = localize('tableName', "Table name"); export const tableName = localize('tableName', "Table name");
export const modelName = localize('models.name', "Name"); export const modelName = localize('models.name', "Name");
export const modelFileName = localize('models.fileName', "File"); export const modelFileName = localize('models.fileName', "File");
export const modelDescription = localize('models.description', "Description"); export const modelDescription = localize('models.description', "Description");
export const modelCreated = localize('models.created', "Date Created"); export const modelCreated = localize('models.created', "Date created");
export const modelDeployed = localize('models.deployed', "Date deployed");
export const modelFramework = localize('models.framework', "Framework");
export const modelFrameworkVersion = localize('models.frameworkVersion', "Framework version");
export const modelVersion = localize('models.version', "Version"); export const modelVersion = localize('models.version', "Version");
export const browseModels = localize('models.browseButton', "..."); export const browseModels = localize('models.browseButton', "...");
export const azureAccount = localize('models.azureAccount', "Azure account"); export const azureAccount = localize('models.azureAccount', "Azure account");
export const azureSignIn = localize('models.azureSignIn', "Sign in to Azure");
export const columnDatabase = localize('predict.columnDatabase', "Target database"); export const columnDatabase = localize('predict.columnDatabase', "Target database");
export const columnTable = localize('predict.columnTable', "Target table"); export const columnTable = localize('predict.columnTable', "Target table");
export const inputColumns = localize('predict.inputColumns', "Model input mapping"); export const inputColumns = localize('predict.inputColumns', "Model input mapping");
@@ -151,15 +160,20 @@ export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy"
export const predictModel = localize('models.predictModel', "Predict"); export const predictModel = localize('models.predictModel', "Predict");
export const registerModelTitle = localize('models.RegisterWizard', "Import models"); export const registerModelTitle = localize('models.RegisterWizard', "Import models");
export const importModelTitle = localize('models.importModelTitle', "Import models"); export const importModelTitle = localize('models.importModelTitle', "Import models");
export const editModelTitle = localize('models.editModelTitle', "Edit model");
export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model"); export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model");
export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions"); export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions");
export const makePredictionDesc = localize('models.makePredictionDesc', "Generates a predicted value or scores using a managed model"); export const makePredictionDesc = localize('models.makePredictionDesc', "Generates a predicted value or scores using a managed model");
export const createNotebookTitle = localize('models.createNotebookTitle', "Create notebook"); export const createNotebookTitle = localize('models.createNotebookTitle', "Create notebook");
export const createNotebookDesc = localize('models.createNotebookDesc', "Run experiments and create models"); export const createNotebookDesc = localize('models.createNotebookDesc', "Run experiments and create models");
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully"); export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
export const modelUpdatedSuccessfully = localize('models.modelUpdatedSuccessfully', "Model updated successfully");
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register"); export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
export const localModelSource = localize('models.localModelSource', "File upload"); export const localModelSource = localize('models.localModelSource', "File upload");
export const localModelPageTitle = localize('models.localModelPageTitle', "Upload model file");
export const azureModelSource = localize('models.azureModelSource', "Azure Machine Learning"); export const azureModelSource = localize('models.azureModelSource', "Azure Machine Learning");
export const azureModelPageTitle = localize('models.azureModelPageTitle', "Import from Azure Machine Learning");
export const importedModelsPageTitle = localize('models.importedModelsPageTitle', "Select imported model");
export const registeredModelsSource = localize('models.registeredModelsSource', "Imported models"); export const registeredModelsSource = localize('models.registeredModelsSource', "Imported models");
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure"); export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource"); export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");

View File

@@ -8,7 +8,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as path from 'path'; import * as path from 'path';
import * as os from 'os'; import * as os from 'os';
import * as fs from 'fs'; import * as fs from 'fs';
import * as constants from '../common/constants'; import * as constants from './constants';
import { promisify } from 'util'; import { promisify } from 'util';
import { ApiWrapper } from './apiWrapper'; import { ApiWrapper } from './apiWrapper';

View File

@@ -134,6 +134,10 @@ export class AzureModelRegistryService {
this._modelClient = value; this._modelClient = value;
} }
public async signInToAzure(): Promise<void> {
await this._apiWrapper.executeCommand(constants.signInToAzureCommand);
}
/** /**
* Execute the background task to download the artifact * Execute the background task to download the artifact
*/ */

View File

@@ -9,9 +9,10 @@ import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils'; import * as utils from '../common/utils';
import { Config } from '../configurations/config'; import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner'; import { QueryRunner } from '../common/queryRunner';
import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces'; import { ImportedModel, ImportedModelDetails, ModelParameters } from './interfaces';
import { ModelPythonClient } from './modelPythonClient'; import { ModelPythonClient } from './modelPythonClient';
import * as constants from '../common/constants'; import * as constants from '../common/constants';
import * as queries from './queries';
import { DatabaseTable } from '../prediction/interfaces'; import { DatabaseTable } from '../prediction/interfaces';
import { ModelConfigRecent } from './modelConfigRecent'; import { ModelConfigRecent } from './modelConfigRecent';
@@ -34,14 +35,14 @@ export class DeployedModelService {
/** /**
* Returns deployed models * Returns deployed models
*/ */
public async getDeployedModels(table: DatabaseTable): Promise<RegisteredModel[]> { public async getDeployedModels(table: DatabaseTable): Promise<ImportedModel[]> {
let connection = await this.getCurrentConnection(); let connection = await this.getCurrentConnection();
let list: RegisteredModel[] = []; let list: ImportedModel[] = [];
if (!table.databaseName || !table.tableName || !table.schema) { if (!table.databaseName || !table.tableName || !table.schema) {
return []; return [];
} }
if (connection) { if (connection) {
const query = this.getDeployedModelsQuery(table); const query = queries.getDeployedModelsQuery(table);
let result = await this._queryRunner.safeRunQuery(connection, query); let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) { if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => { result.rows.forEach(row => {
@@ -58,10 +59,10 @@ export class DeployedModelService {
* Downloads model * Downloads model
* @param model model object * @param model model object
*/ */
public async downloadModel(model: RegisteredModel): Promise<string> { public async downloadModel(model: ImportedModel): Promise<string> {
let connection = await this.getCurrentConnection(); let connection = await this.getCurrentConnection();
if (connection) { if (connection) {
const query = this.getModelContentQuery(model); const query = queries.getModelContentQuery(model);
let result = await this._queryRunner.safeRunQuery(connection, query); let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) { if (result && result.rows && result.rows.length > 0) {
const content = result.rows[0][0].displayValue; const content = result.rows[0][0].displayValue;
@@ -86,29 +87,23 @@ export class DeployedModelService {
* @param filePath model file path * @param filePath model file path
* @param details model details * @param details model details
*/ */
public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined, table: DatabaseTable) { public async deployLocalModel(filePath: string, details: ImportedModelDetails | undefined, table: DatabaseTable) {
let connection = await this.getCurrentConnection(); let connection = await this.getCurrentConnection();
if (connection && table.databaseName) { if (connection && table.databaseName) {
await this.configureImport(connection, table); await this.configureImport(connection, table);
let currentModels = await this.getDeployedModels(table); let currentModels = await this.getDeployedModels(table);
const content = await utils.readFileInHex(filePath); const content = await utils.readFileInHex(filePath);
const fileName = details?.fileName || utils.getFileName(filePath); let modelToAdd: ImportedModel = Object.assign({}, {
let modelToAdd: RegisteredModel = {
id: 0, id: 0,
artifactName: fileName,
content: content, content: content,
title: details?.title || fileName,
description: details?.description,
version: details?.version,
table: table table: table
}; }, details);
await this._queryRunner.runWithDatabaseChange(connection, this.getInsertModelQuery(modelToAdd, table), table.databaseName); await this._queryRunner.runWithDatabaseChange(connection, queries.getInsertModelQuery(modelToAdd, table), table.databaseName);
let updatedModels = await this.getDeployedModels(table); let updatedModels = await this.getDeployedModels(table);
if (updatedModels.length < currentModels.length + 1) { if (updatedModels.length < currentModels.length + 1) {
throw Error(constants.importModelFailedError(details?.title, filePath)); throw Error(constants.importModelFailedError(details?.modelName, filePath));
} }
} else { } else {
@@ -116,12 +111,36 @@ export class DeployedModelService {
} }
} }
/**
* Updates a model
*/
public async updateModel(model: ImportedModel) {
let connection = await this.getCurrentConnection();
if (connection && model && model.table && model.table.databaseName) {
await this._queryRunner.runWithDatabaseChange(connection, queries.getUpdateModelQuery(model), model.table.databaseName);
} else {
throw new Error(constants.noConnectionError);
}
}
/**
* Updates a model
*/
public async deleteModel(model: ImportedModel) {
let connection = await this.getCurrentConnection();
if (connection && model && model.table && model.table.databaseName) {
await this._queryRunner.runWithDatabaseChange(connection, queries.getDeleteModelQuery(model), model.table.databaseName);
} else {
throw new Error(constants.noConnectionError);
}
}
public async configureImport(connection: azdata.connection.ConnectionProfile, table: DatabaseTable) { public async configureImport(connection: azdata.connection.ConnectionProfile, table: DatabaseTable) {
if (connection && table.databaseName) { if (connection && table.databaseName) {
let query = this.getDatabaseConfigureQuery(table); let query = queries.getDatabaseConfigureQuery(table);
await this._queryRunner.safeRunQuery(connection, query); await this._queryRunner.safeRunQuery(connection, query);
query = this.getConfigureTableQuery(table); query = queries.getConfigureTableQuery(table);
await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName); await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
} }
} }
@@ -140,7 +159,7 @@ export class DeployedModelService {
// If database exist verify the table schema // If database exist verify the table schema
// //
if ((await databases).find(x => x === table.databaseName)) { if ((await databases).find(x => x === table.databaseName)) {
const query = this.getConfigTableVerificationQuery(table); const query = queries.getConfigTableVerificationQuery(table);
const result = await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName); const result = await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
return result !== undefined && result.rows.length > 0 && result.rows[0][0].displayValue === '1'; return result !== undefined && result.rows.length > 0 && result.rows[0][0].displayValue === '1';
} else { } else {
@@ -178,14 +197,18 @@ export class DeployedModelService {
} }
} }
private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): RegisteredModel { private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): ImportedModel {
return { return {
id: +row[0].displayValue, id: +row[0].displayValue,
artifactName: row[1].displayValue, modelName: row[1].displayValue,
title: row[2].displayValue, description: row[2].displayValue,
description: row[3].displayValue, version: row[3].displayValue,
version: row[4].displayValue, created: row[4].displayValue,
created: row[5].displayValue, framework: row[5].displayValue,
frameworkVersion: row[6].displayValue,
deploymentTime: row[7].displayValue,
deployedBy: row[8].displayValue,
runId: row[9].displayValue,
table: table table: table
}; };
} }
@@ -193,160 +216,4 @@ export class DeployedModelService {
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> { private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection(); return await this._apiWrapper.getCurrentConnection();
} }
public getDatabaseConfigureQuery(configTable: DatabaseTable): string {
return `
IF NOT EXISTS (
SELECT name
FROM sys.databases
WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}'
)
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}]
`;
}
public getDeployedModelsQuery(table: DatabaseTable): string {
return `
SELECT artifact_id, artifact_name, name, description, version, created
FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')}
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
}
/**
* Verifies config table has the expected schema
* @param databaseName
* @param tableName
*/
public getConfigTableVerificationQuery(table: DatabaseTable): string {
let tableName = table.tableName;
let schemaName = table.schema;
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
return `
IF NOT EXISTS (
SELECT name
FROM sys.databases
WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}'
)
BEGIN
Select 1
END
ELSE
BEGIN
USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}]
IF EXISTS
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created')
BEGIN
Select 1
END
ELSE
BEGIN
Select 0
END
END
ELSE
select 1
END
`;
}
/**
* Update the table and adds extra columns (name, description, version) if doesn't already exist.
* Note: this code is temporary and will be removed weh the table supports the required schema
* @param databaseName
* @param tableName
*/
public getConfigureTableQuery(table: DatabaseTable): string {
let tableName = table.tableName;
let schemaName = table.schema;
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
return `
IF EXISTS
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name')
ALTER TABLE ${twoPartTableName} ADD [artifact_name] [varchar](256) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content')
ALTER TABLE ${twoPartTableName} ADD [artifact_content] [varbinary](max) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name')
ALTER TABLE ${twoPartTableName} ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version')
ALTER TABLE ${twoPartTableName} ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created')
BEGIN
ALTER TABLE ${twoPartTableName} ADD [created] [datetime] NULL
ALTER TABLE ${twoPartTableName} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
END
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='description')
ALTER TABLE ${twoPartTableName} ADD [description] [varchar](256) NULL
END
Else
BEGIN
CREATE TABLE ${twoPartTableName}(
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
[artifact_name] [varchar](256) NOT NULL,
[artifact_content] [varbinary](max) NOT NULL,
[artifact_initial_size] [bigint] NULL,
[name] [varchar](256) NULL,
[version] [varchar](256) NULL,
[created] [datetime] NULL,
[description] [varchar](256) NULL,
CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_artifact_pk] PRIMARY KEY CLUSTERED
(
[artifact_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[${utils.doubleEscapeSingleBrackets(tableName)}] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
END
`;
}
public getInsertModelQuery(model: RegisteredModel, table: DatabaseTable): string {
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '');
let updateScript = `
Insert into ${twoPartTableName}
(artifact_name, artifact_content, name, version, description)
values (
'${utils.doubleEscapeSingleQuotes(model.artifactName || '')}',
${utils.doubleEscapeSingleQuotes(model.content || '')},
'${utils.doubleEscapeSingleQuotes(model.title || '')}',
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
'${utils.doubleEscapeSingleQuotes(model.description || '')}')
`;
return `
${updateScript}
SELECT artifact_id, artifact_name, name, description, version, created
FROM ${threePartTableName}
WHERE artifact_id = SCOPE_IDENTITY();
`;
}
public getModelContentQuery(model: RegisteredModel): string {
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
return `
SELECT artifact_content
FROM ${threePartTableName}
WHERE artifact_id = ${model.id};
`;
}
} }

View File

@@ -47,11 +47,10 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
}; };
/** /**
* An interface representing registered model * An interface representing imported model
*/ */
export interface RegisteredModel extends RegisteredModelDetails { export interface ImportedModel extends ImportedModelDetails {
id: number; id: number;
artifactName: string;
content?: string; content?: string;
table: DatabaseTable; table: DatabaseTable;
} }
@@ -67,14 +66,19 @@ export interface ModelParameters {
} }
/** /**
* An interface representing registered model * An interface representing imported model
*/ */
export interface RegisteredModelDetails { export interface ImportedModelDetails {
title: string; modelName: string;
created?: string; created?: string;
deploymentTime?: string;
version?: string; version?: string;
description?: string; description?: string;
fileName?: string; fileName?: string;
framework?: string;
frameworkVersion?: string;
runId?: string;
deployedBy?: string;
} }
/** /**

View File

@@ -0,0 +1,195 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as utils from '../common/utils';
import { DatabaseTable } from '../prediction/interfaces';
import { ImportedModel } from './interfaces';
export function getDatabaseConfigureQuery(configTable: DatabaseTable): string {
return `
IF NOT EXISTS (
SELECT name
FROM sys.databases
WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}'
)
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}]
`;
}
export function getDeployedModelsQuery(table: DatabaseTable): string {
return `
${selectQuery}
FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')}
WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
ORDER BY model_id
`;
}
/**
* Verifies config table has the expected schema
* @param databaseName
* @param tableName
*/
export function getConfigTableVerificationQuery(table: DatabaseTable): string {
let tableName = table.tableName;
let schemaName = table.schema;
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
return `
IF NOT EXISTS (
SELECT name
FROM sys.databases
WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}'
)
BEGIN
SELECT 1
END
ELSE
BEGIN
USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}]
IF EXISTS
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_name')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_id')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_description')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework_version')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_version')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_creation_time')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_deployment_time')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='deployed_by')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='run_id')
BEGIN
SELECT 1
END
ELSE
BEGIN
SELECT 0
END
END
ELSE
SELECT 1
END
`;
}
/**
* Creates the import table if doesn't exist
*/
export function getConfigureTableQuery(table: DatabaseTable): string {
let tableName = table.tableName;
let schemaName = table.schema;
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
return `
IF NOT EXISTS
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
CREATE TABLE ${twoPartTableName}(
[model_id] [int] IDENTITY(1,1) NOT NULL,
[model_name] [varchar](256) NOT NULL,
[model_framework] [varchar](256) NULL,
[model_framework_version] [varchar](256) NULL,
[model] [varbinary](max) NOT NULL,
[model_version] [varchar](256) NULL,
[model_creation_time] [datetime2] NULL,
[model_deployment_time] [datetime2] NULL,
[deployed_by] [int] NULL,
[model_description] [varchar](256) NULL,
[run_id] [varchar](256) NULL,
CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_models_pk] PRIMARY KEY CLUSTERED
(
[model_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE ${twoPartTableName} ADD CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
END
`;
}
export function getInsertModelQuery(model: ImportedModel, table: DatabaseTable): string {
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '');
let updateScript = `
INSERT INTO ${twoPartTableName}
(model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id)
VALUES (
'${utils.doubleEscapeSingleQuotes(model.modelName || '')}',
${utils.doubleEscapeSingleQuotes(model.content || '')},
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
'${utils.doubleEscapeSingleQuotes(model.description || '')}',
'${utils.doubleEscapeSingleQuotes(model.created || '')}',
'${utils.doubleEscapeSingleQuotes(model.framework || '')}',
'${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
'${utils.doubleEscapeSingleQuotes(model.runId || '')}')
`;
return `
${updateScript}
${selectQuery}
FROM ${threePartTableName}
WHERE model_id = SCOPE_IDENTITY();
`;
}
export function getModelContentQuery(model: ImportedModel): string {
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
return `
SELECT model
FROM ${threePartTableName}
WHERE model_id = ${model.id};
`;
}
export function getUpdateModelQuery(model: ImportedModel): string {
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || '');
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
let updateScript = `
UPDATE ${twoPartTableName}
SET
model_name = '${utils.doubleEscapeSingleQuotes(model.modelName || '')}',
model_version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
model_description = '${utils.doubleEscapeSingleQuotes(model.description || '')}',
model_creation_time = '${utils.doubleEscapeSingleQuotes(model.created || '')}',
model_framework = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
model_framework_version = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
run_id = '${utils.doubleEscapeSingleQuotes(model.runId || '')}'
WHERE model_id = ${model.id}`;
return `
${updateScript}
${selectQuery}
FROM ${threePartTableName}
WHERE model_id = ${model.id};
`;
}
export function getDeleteModelQuery(model: ImportedModel): string {
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || '');
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
let updateScript = `
Delete from ${twoPartTableName}
WHERE model_id = ${model.id}`;
return `
${updateScript}
${selectQuery}
FROM ${threePartTableName}
`;
}
export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id';

View File

@@ -179,8 +179,8 @@ export class PackageManager {
let cmd = `"${this.pythonExecutable}" -m pip list --format=json`; let cmd = `"${this.pythonExecutable}" -m pip list --format=json`;
let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined); let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined);
let packagesResult: nbExtensionApis.IPackageDetails[] = []; let packagesResult: nbExtensionApis.IPackageDetails[] = [];
if (packagesInfo) { if (packagesInfo && packagesInfo.indexOf(']') > 0) {
packagesResult = <nbExtensionApis.IPackageDetails[]>JSON.parse(packagesInfo); packagesResult = <nbExtensionApis.IPackageDetails[]>JSON.parse(packagesInfo.substr(0, packagesInfo.indexOf(']') + 1));
} }
return packagesResult; return packagesResult;
} }

View File

@@ -8,7 +8,7 @@ import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper'; import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner'; import { QueryRunner } from '../common/queryRunner';
import * as utils from '../common/utils'; import * as utils from '../common/utils';
import { RegisteredModel } from '../modelManagement/interfaces'; import { ImportedModel } from '../modelManagement/interfaces';
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces'; import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
/** /**
@@ -42,7 +42,7 @@ export class PredictService {
*/ */
public async generatePredictScript( public async generatePredictScript(
predictParams: PredictParameters, predictParams: PredictParameters,
registeredModel: RegisteredModel | undefined, registeredModel: ImportedModel | undefined,
filePath: string | undefined filePath: string | undefined
): Promise<string> { ): Promise<string> {
let connection = await this.getCurrentConnection(); let connection = await this.getCurrentConnection();
@@ -146,9 +146,9 @@ WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuo
const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || ''); const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || '');
return ` return `
DECLARE @model VARBINARY(max) = ( DECLARE @model VARBINARY(max) = (
SELECT artifact_content SELECT model
FROM ${threePartTableName} FROM ${threePartTableName}
WHERE artifact_id = ${modelId} WHERE model_id = ${modelId}
); );
WITH predict_input WITH predict_input
AS ( AS (

View File

@@ -11,7 +11,7 @@ import * as should from 'should';
import { Config } from '../../configurations/config'; import { Config } from '../../configurations/config';
import { DeployedModelService } from '../../modelManagement/deployedModelService'; import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { QueryRunner } from '../../common/queryRunner'; import { QueryRunner } from '../../common/queryRunner';
import { RegisteredModel } from '../../modelManagement/interfaces'; import { ImportedModel } from '../../modelManagement/interfaces';
import { ModelPythonClient } from '../../modelManagement/modelPythonClient'; import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
import * as path from 'path'; import * as path from 'path';
import * as os from 'os'; import * as os from 'os';
@@ -19,6 +19,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as fs from 'fs'; import * as fs from 'fs';
import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent'; import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent';
import { DatabaseTable } from '../../prediction/interfaces'; import { DatabaseTable } from '../../prediction/interfaces';
import * as queries from '../../modelManagement/queries';
interface TestContext { interface TestContext {
@@ -70,14 +71,18 @@ describe('DeployedModelService', () => {
const testContext = createContext(); const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile(); const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const expected: RegisteredModel[] = [ const expected: ImportedModel[] = [
{ {
id: 1, id: 1,
artifactName: 'name1', modelName: 'name1',
title: 'title1',
description: 'desc1', description: 'desc1',
created: '2018-01-01', created: '2018-01-01',
deploymentTime: '2018-01-01',
version: '1.1', version: '1.1',
framework: 'onnx',
frameworkVersion: '1',
deployedBy: '1',
runId: 'run1',
table: testContext.importTable table: testContext.importTable
} }
@@ -97,11 +102,6 @@ describe('DeployedModelService', () => {
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
}, },
{
displayValue: 'title1',
isNull: false,
invariantCultureDisplayValue: ''
},
{ {
displayValue: 'desc1', displayValue: 'desc1',
isNull: false, isNull: false,
@@ -116,6 +116,31 @@ describe('DeployedModelService', () => {
displayValue: '2018-01-01', displayValue: '2018-01-01',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
},
{
displayValue: 'onnx',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'run1',
isNull: false,
invariantCultureDisplayValue: ''
} }
] ]
] ]
@@ -127,9 +152,6 @@ describe('DeployedModelService', () => {
testContext.modelClient.object, testContext.modelClient.object,
testContext.recentModels.object); testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
const actual = await service.getDeployedModels(testContext.importTable); const actual = await service.getDeployedModels(testContext.importTable);
should.deepEqual(actual, expected); should.deepEqual(actual, expected);
}); });
@@ -171,14 +193,18 @@ describe('DeployedModelService', () => {
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`); const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, 'test'); await fs.promises.writeFile(tempFilePath, 'test');
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const model: RegisteredModel = const model: ImportedModel =
{ {
id: 1, id: 1,
artifactName: 'name1', modelName: 'name1',
title: 'title1',
description: 'desc1', description: 'desc1',
created: '2018-01-01', created: '2018-01-01',
deploymentTime: '2018-01-01',
version: '1.1', version: '1.1',
framework: 'onnx',
frameworkVersion: '1',
deployedBy: '1',
runId: 'run1',
table: testContext.importTable table: testContext.importTable
}; };
const result = { const result = {
@@ -213,47 +239,72 @@ describe('DeployedModelService', () => {
const testContext = createContext(); const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile(); const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const model: RegisteredModel = const model: ImportedModel =
{ {
id: 1, id: 1,
artifactName: 'name1', modelName: 'name1',
title: 'title1',
description: 'desc1', description: 'desc1',
created: '2018-01-01', created: '2018-01-01',
deploymentTime: '2018-01-01',
version: '1.1', version: '1.1',
framework: 'onnx',
frameworkVersion: '1',
deployedBy: '1',
runId: 'run1',
table: testContext.importTable table: testContext.importTable
}; };
const row = [ const row = [
{ {
displayValue: '1', displayValue: '1',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
}, },
{ {
displayValue: 'name1', displayValue: 'name1',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
}, },
{ {
displayValue: 'title1', displayValue: 'desc1',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
}, },
{ {
displayValue: 'desc1', displayValue: '1.1',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
}, },
{ {
displayValue: '1.1', displayValue: '2018-01-01',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
}, },
{ {
displayValue: '2018-01-01', displayValue: 'onnx',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
} },
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'run1',
isNull: false,
invariantCultureDisplayValue: ''
}
]; ];
const result = { const result = {
rowCount: 1, rowCount: 1,
@@ -273,7 +324,7 @@ describe('DeployedModelService', () => {
testContext.modelClient.object, testContext.modelClient.object,
testContext.recentModels.object); testContext.recentModels.object);
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0), TypeMoq.It.isAny())).returns(() => { testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('INSERT INTO') > 0), TypeMoq.It.isAny())).returns(() => {
deployed = true; deployed = true;
return Promise.resolve(result); return Promise.resolve(result);
}); });
@@ -298,145 +349,105 @@ describe('DeployedModelService', () => {
it('getConfigureQuery should escape db name', async function (): Promise<void> { it('getConfigureQuery should escape db name', async function (): Promise<void> {
const testContext = createContext(); const testContext = createContext();
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
testContext.importTable.databaseName = 'd[]b'; testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le'; testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo'; testContext.importTable.schema = 'dbo';
const expected = ` const expected = `
IF EXISTS IF NOT EXISTS
( SELECT t.name, s.name ( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = 'ta[b]le' WHERE t.name = 'ta[b]le'
AND s.name = 'dbo' AND s.name = 'dbo'
) )
BEGIN
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_name')
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_name] [varchar](256) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_content')
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_content] [varbinary](max) NOT NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
ALTER TABLE [dbo].[ta[[b]]le] ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='created')
BEGIN
ALTER TABLE [dbo].[ta[[b]]le] ADD [created] [datetime] NULL
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
END
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='description')
ALTER TABLE [dbo].[ta[[b]]le] ADD [description] [varchar](256) NULL
END
Else
BEGIN BEGIN
CREATE TABLE [dbo].[ta[[b]]le]( CREATE TABLE [dbo].[ta[[b]]le](
[artifact_id] [int] IDENTITY(1,1) NOT NULL, [model_id] [int] IDENTITY(1,1) NOT NULL,
[artifact_name] [varchar](256) NOT NULL, [model_name] [varchar](256) NOT NULL,
[artifact_content] [varbinary](max) NOT NULL, [model_framework] [varchar](256) NULL,
[artifact_initial_size] [bigint] NULL, [model_framework_version] [varchar](256) NULL,
[name] [varchar](256) NULL, [model] [varbinary](max) NOT NULL,
[version] [varchar](256) NULL, [model_version] [varchar](256) NULL,
[created] [datetime] NULL, [model_creation_time] [datetime2] NULL,
[description] [varchar](256) NULL, [model_deployment_time] [datetime2] NULL,
CONSTRAINT [ta[[b]]le_artifact_pk] PRIMARY KEY CLUSTERED [deployed_by] [int] NULL,
[model_description] [varchar](256) NULL,
[run_id] [varchar](256) NULL,
CONSTRAINT [ta[[b]]le_models_pk] PRIMARY KEY CLUSTERED
( (
[artifact_id] ASC [model_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [ta[[b]]le_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
END END
`; `;
const actual = service.getConfigureTableQuery(testContext.importTable); const actual = queries.getConfigureTableQuery(testContext.importTable);
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`); should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
}); });
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> { it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
const testContext = createContext(); const testContext = createContext();
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
testContext.importTable.databaseName = 'd[]b'; testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le'; testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo'; testContext.importTable.schema = 'dbo';
const expected = ` const expected = `
SELECT artifact_id, artifact_name, name, description, version, created SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id
FROM [d[[]]b].[dbo].[ta[[b]]le] FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml' WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
Order by artifact_id ORDER BY model_id
`; `;
const actual = service.getDeployedModelsQuery(testContext.importTable); const actual = queries.getDeployedModelsQuery(testContext.importTable);
should.deepEqual(expected, actual); should.deepEqual(expected, actual);
}); });
it('getInsertModelQuery should escape db name', async function (): Promise<void> { it('getInsertModelQuery should escape db name', async function (): Promise<void> {
const testContext = createContext(); const testContext = createContext();
const model: RegisteredModel = const model: ImportedModel =
{ {
id: 1, id: 1,
artifactName: 'name1', modelName: 'name1',
title: 'title1',
description: 'desc1', description: 'desc1',
created: '2018-01-01', created: '2018-01-01',
version: '1.1', version: '1.1',
table: testContext.importTable table: testContext.importTable
}; };
let service = new DeployedModelService( const expected = `INSERT INTO [dbo].[tb]
testContext.apiWrapper.object, (model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id)
testContext.config.object, VALUES (
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
const expected = `
Insert into [dbo].[tb]
(artifact_name, artifact_content, name, version, description)
values (
'name1', 'name1',
, ,
'title1',
'1.1', '1.1',
'desc1')`; 'desc1',
const actual = service.getInsertModelQuery(model, testContext.importTable); '2018-01-01',
should.equal(actual.indexOf(expected) > 0, true); '',
'',
'')`;
const actual = queries.getInsertModelQuery(model, testContext.importTable);
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
}); });
it('getModelContentQuery should escape db name', async function (): Promise<void> { it('getModelContentQuery should escape db name', async function (): Promise<void> {
const testContext = createContext(); const testContext = createContext();
const model: RegisteredModel = const model: ImportedModel =
{ {
id: 1, id: 1,
artifactName: 'name1', modelName: 'name1',
title: 'title1',
description: 'desc1', description: 'desc1',
created: '2018-01-01', created: '2018-01-01',
version: '1.1', version: '1.1',
table: testContext.importTable table: testContext.importTable
}; };
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
model.table = { model.table = {
databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo' databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo'
}; };
const expected = ` const expected = `
SELECT artifact_content SELECT model
FROM [d[[]]b].[dbo].[ta[[b]]le] FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE artifact_id = 1; WHERE model_id = 1;
`; `;
const actual = service.getModelContentQuery(model); const actual = queries.getModelContentQuery(model);
should.deepEqual(actual, expected); should.deepEqual(actual, expected, `actual: ${actual} \n expected: ${expected}`);
}); });
}); });

View File

@@ -10,7 +10,7 @@ import * as TypeMoq from 'typemoq';
import * as should from 'should'; import * as should from 'should';
import { PredictService } from '../../prediction/predictService'; import { PredictService } from '../../prediction/predictService';
import { QueryRunner } from '../../common/queryRunner'; import { QueryRunner } from '../../common/queryRunner';
import { RegisteredModel } from '../../modelManagement/interfaces'; import { ImportedModel } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import * as path from 'path'; import * as path from 'path';
import * as os from 'os'; import * as os from 'os';
@@ -194,11 +194,10 @@ describe('PredictService', () => {
tableName: '', tableName: '',
schema: '' schema: ''
}; };
const model: RegisteredModel = const model: ImportedModel =
{ {
id: 1, id: 1,
artifactName: 'name1', modelName: 'name1',
title: 'title1',
description: 'desc1', description: 'desc1',
created: '2018-01-01', created: '2018-01-01',
version: '1.1', version: '1.1',

View File

@@ -8,12 +8,14 @@ import * as should from 'should';
import * as TypeMoq from 'typemoq'; import * as TypeMoq from 'typemoq';
import 'mocha'; import 'mocha';
import { createContext } from './utils'; import { createContext } from './utils';
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces'; import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource'; import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel } from '../../../modelManagement/interfaces'; import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { ModelManagementController } from '../../../views/models/modelManagementController'; import { ModelManagementController } from '../../../views/models/modelManagementController';
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces'; import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
import { DeleteModelEventName, UpdateModelEventName } from '../../../views/models/modelViewBase';
import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog';
const accounts: azdata.Account[] = [ const accounts: azdata.Account[] = [
{ {
@@ -55,11 +57,10 @@ const models: WorkspaceModel[] = [
name: 'model' name: 'model'
} }
]; ];
const localModels: RegisteredModel[] = [ const localModels: ImportedModel[] = [
{ {
id: 1, id: 1,
artifactName: 'model', modelName: 'model',
title: 'model',
table: { table: {
databaseName: 'db', databaseName: 'db',
tableName: 'tb', tableName: 'tb',
@@ -167,4 +168,35 @@ describe('Model Controller', () => {
const view = await controller.predictModel(); const view = await controller.predictModel();
should.notEqual(view, undefined); should.notEqual(view, undefined);
}); });
it('Should open edit model dialog successfully ', async function (): Promise<void> {
let testContext = createContext();
testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.deployModelService.setup(x => x.deleteModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
version: '1.1',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
};
const view = <EditModelDialog>await controller.editModel(model);
should.notEqual(view?.editModelPage, undefined);
if (view.editModelPage) {
view.editModelPage.sendRequest(UpdateModelEventName, model);
view.editModelPage.sendRequest(DeleteModelEventName, model);
}
testContext.deployModelService.verify(x => x.updateModel(model), TypeMoq.Times.atLeastOnce());
testContext.deployModelService.verify(x => x.deleteModel(model), TypeMoq.Times.atLeastOnce());
should.notEqual(view, undefined);
});
}); });

View File

@@ -0,0 +1,33 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog';
describe('Edit Model Dialog', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
version: '1.1',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
};
let view = new EditModelDialog(testContext.apiWrapper.object, '', undefined, model);
view.open();
should.notEqual(view.dialogView, undefined);
});
});

View File

@@ -12,7 +12,7 @@ import {
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType
} }
from '../../../views/models/modelViewBase'; from '../../../views/models/modelViewBase';
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces'; import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource'; import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase'; import { ViewBase } from '../../../views/viewBase';
@@ -80,11 +80,10 @@ describe('Predict Wizard', () => {
name: 'model' name: 'model'
} }
]; ];
let localModels: RegisteredModel[] = [ let localModels: ImportedModel[] = [
{ {
id: 1, id: 1,
artifactName: 'model', modelName: 'model',
title: 'model',
table: { table: {
databaseName: 'db', databaseName: 'db',
tableName: 'tb', tableName: 'tb',

View File

@@ -8,7 +8,7 @@ import * as should from 'should';
import 'mocha'; import 'mocha';
import { createContext } from './utils'; import { createContext } from './utils';
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase'; import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase';
import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ImportedModel } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource'; import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase'; import { ViewBase } from '../../../views/viewBase';
@@ -80,11 +80,10 @@ describe('Register Model Wizard', () => {
name: 'model' name: 'model'
} }
]; ];
let localModels: RegisteredModel[] = [ let localModels: ImportedModel[] = [
{ {
id: 1, id: 1,
artifactName: 'model', modelName: 'model',
title: 'model',
table: { table: {
databaseName: 'db', databaseName: 'db',
tableName: 'tb', tableName: 'tb',

View File

@@ -8,7 +8,7 @@ import 'mocha';
import { createContext } from './utils'; import { createContext } from './utils';
import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog'; import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog';
import { ListModelsEventName } from '../../../views/models/modelViewBase'; import { ListModelsEventName } from '../../../views/models/modelViewBase';
import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ImportedModel } from '../../../modelManagement/interfaces';
import { ViewBase } from '../../../views/viewBase'; import { ViewBase } from '../../../views/viewBase';
describe('Registered Models Dialog', () => { describe('Registered Models Dialog', () => {
@@ -27,11 +27,10 @@ describe('Registered Models Dialog', () => {
let view = new ManageModelsDialog(testContext.apiWrapper.object, ''); let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
view.open(); view.open();
let models: RegisteredModel[] = [ let models: ImportedModel[] = [
{ {
id: 1, id: 1,
artifactName: 'model', modelName: 'model',
title: '',
table: { table: {
databaseName: 'db', databaseName: 'db',
tableName: 'tb', tableName: 'tb',

View File

@@ -34,4 +34,10 @@ export interface AzureModelResource extends AzureWorkspaceResource {
model?: WorkspaceModel; model?: WorkspaceModel;
} }
export interface IComponentSettings {
multiSelect?: boolean;
editable?: boolean;
selectable?: boolean;
}

View File

@@ -10,11 +10,13 @@ import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
import { AzureModelsTable } from './azureModelsTable'; import { AzureModelsTable } from './azureModelsTable';
import { IDataComponent, AzureModelResource } from '../interfaces'; import { IDataComponent, AzureModelResource } from '../interfaces';
import { ModelArtifact } from './prediction/modelArtifact'; import { ModelArtifact } from './prediction/modelArtifact';
import { AzureSignInComponent } from './azureSignInComponent';
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource[]> { export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource[]> {
public azureModelsTable: AzureModelsTable | undefined; public azureModelsTable: AzureModelsTable | undefined;
public azureFilterComponent: AzureResourceFilterComponent | undefined; public azureFilterComponent: AzureResourceFilterComponent | undefined;
public azureSignInComponent: AzureSignInComponent | undefined;
private _loader: azdata.LoadingComponent | undefined; private _loader: azdata.LoadingComponent | undefined;
private _form: azdata.FormContainer | undefined; private _form: azdata.FormContainer | undefined;
@@ -34,6 +36,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this); this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this);
this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect); this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect);
this.azureSignInComponent = new AzureSignInComponent(this._apiWrapper, modelBuilder, this);
this._loader = modelBuilder.loadingComponent() this._loader = modelBuilder.loadingComponent()
.withItem(this.azureModelsTable.component) .withItem(this.azureModelsTable.component)
.withProperties({ .withProperties({
@@ -63,6 +66,20 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
} }
public addComponents(formBuilder: azdata.FormBuilder) { public addComponents(formBuilder: azdata.FormBuilder) {
this.removeComponents(formBuilder);
if (this.azureFilterComponent?.data?.account) {
this.addAzureComponents(formBuilder);
} else {
this.addAzureSignInComponents(formBuilder);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
this.removeAzureComponents(formBuilder);
this.removeAzureSignInComponents(formBuilder);
}
private addAzureComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) { if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.addComponents(formBuilder); this.azureFilterComponent.addComponents(formBuilder);
@@ -73,7 +90,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
} }
} }
public removeComponents(formBuilder: azdata.FormBuilder) { private removeAzureComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) { if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.removeComponents(formBuilder); this.azureFilterComponent.removeComponents(formBuilder);
formBuilder.removeFormItem({ formBuilder.removeFormItem({
@@ -83,6 +100,18 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
} }
} }
private addAzureSignInComponents(formBuilder: azdata.FormBuilder) {
if (this.azureSignInComponent) {
this.azureSignInComponent.addComponents(formBuilder);
}
}
private removeAzureSignInComponents(formBuilder: azdata.FormBuilder) {
if (this.azureSignInComponent) {
this.azureSignInComponent.removeComponents(formBuilder);
}
}
private async onLoading(): Promise<void> { private async onLoading(): Promise<void> {
if (this._loader) { if (this._loader) {
await this._loader.updateProperties({ loading: true }); await this._loader.updateProperties({ loading: true });

View File

@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SignInToAzureEventName } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 300;
export class AzureSignInComponent extends ModelViewBase {
private _form: azdata.FormContainer;
private _signInButton: azdata.ButtonComponent;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._signInButton = this._modelBuilder.button().withProperties({
width: componentWidth,
label: constants.azureSignIn,
}).component();
this._signInButton.onDidClick(() => {
this.sendRequest(SignInToAzureEventName);
});
this._form = this._modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._signInButton
}]).component();
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._signInButton) {
formBuilder.addFormItems([{
title: constants.azureAccount,
component: this._signInButton
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._signInButton) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._signInButton
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component {
return this._form;
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -9,9 +9,9 @@ import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase'; import { ModelViewBase } from '../modelViewBase';
import { CurrentModelsTable } from './currentModelsTable'; import { CurrentModelsTable } from './currentModelsTable';
import { ApiWrapper } from '../../../common/apiWrapper'; import { ApiWrapper } from '../../../common/apiWrapper';
import { IPageView } from '../../interfaces'; import { IPageView, IComponentSettings } from '../../interfaces';
import { TableSelectionComponent } from '../tableSelectionComponent'; import { TableSelectionComponent } from '../tableSelectionComponent';
import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ImportedModel } from '../../../modelManagement/interfaces';
/** /**
* View to render current registered models * View to render current registered models
@@ -27,7 +27,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
* @param apiWrapper Creates new view * @param apiWrapper Creates new view
* @param parent page parent * @param parent page parent
*/ */
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = false) { constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
super(apiWrapper, parent.root, parent); super(apiWrapper, parent.root, parent);
} }
@@ -41,7 +41,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
this._tableSelectionComponent.onSelectedChanged(async () => { this._tableSelectionComponent.onSelectedChanged(async () => {
await this.onTableSelected(); await this.onTableSelected();
}); });
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect); this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings);
this._dataTable.registerComponent(modelBuilder); this._dataTable.registerComponent(modelBuilder);
this._tableComponent = this._dataTable.component; this._tableComponent = this._dataTable.component;
@@ -102,7 +102,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
} }
} }
public get data(): RegisteredModel[] | undefined { public get data(): ImportedModel[] | undefined {
return this._dataTable?.data; return this._dataTable?.data;
} }

View File

@@ -6,20 +6,21 @@
import * as azdata from 'azdata'; import * as azdata from 'azdata';
import * as vscode from 'vscode'; import * as vscode from 'vscode';
import * as constants from '../../../common/constants'; import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase'; import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper'; import { ApiWrapper } from '../../../common/apiWrapper';
import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ImportedModel } from '../../../modelManagement/interfaces';
import { IDataComponent } from '../../interfaces'; import { IDataComponent, IComponentSettings } from '../../interfaces';
import { ModelArtifact } from '../prediction/modelArtifact'; import { ModelArtifact } from '../prediction/modelArtifact';
import * as utils from '../../../common/utils';
/** /**
* View to render registered models table * View to render registered models table
*/ */
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<RegisteredModel[]> { export class CurrentModelsTable extends ModelViewBase implements IDataComponent<ImportedModel[]> {
private _table: azdata.DeclarativeTableComponent | undefined; private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined; private _modelBuilder: azdata.ModelBuilder | undefined;
private _selectedModel: RegisteredModel[] = []; private _selectedModel: ImportedModel[] = [];
private _loader: azdata.LoadingComponent | undefined; private _loader: azdata.LoadingComponent | undefined;
private _downloadedFile: ModelArtifact | undefined; private _downloadedFile: ModelArtifact | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>(); private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
@@ -28,7 +29,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
/** /**
* Creates new view * Creates new view
*/ */
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
super(apiWrapper, parent.root, parent); super(apiWrapper, parent.root, parent);
} }
@@ -38,62 +39,66 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
*/ */
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._modelBuilder = modelBuilder; this._modelBuilder = modelBuilder;
let columns = [
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
if (this._settings.editable) {
columns.push(
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
);
}
this._table = modelBuilder.declarativeTable() this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>( .withProperties<azdata.DeclarativeTableProperties>(
{ {
columns: [ columns: columns,
{ // Artifact name
displayName: constants.modelArtifactName,
ariaLabel: constants.modelArtifactName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [], data: [],
ariaLabel: constants.mlsConfigTitle ariaLabel: constants.mlsConfigTitle
}) })
@@ -132,7 +137,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
public async loadData(): Promise<void> { public async loadData(): Promise<void> {
await this.onLoading(); await this.onLoading();
if (this._table) { if (this._table) {
let models: RegisteredModel[] | undefined; let models: ImportedModel[] | undefined;
if (this.importTable) { if (this.importTable) {
models = await this.listModels(this.importTable); models = await this.listModels(this.importTable);
@@ -163,11 +168,28 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
} }
} }
private createTableRow(model: RegisteredModel): any[] { private createTableRow(model: ImportedModel): any[] {
let row: any[] = [model.modelName, model.created];
if (this._modelBuilder) { if (this._modelBuilder) {
let selectModelButton: azdata.Component; const selectButton = this.createSelectButton(model);
if (selectButton) {
row.push(selectButton);
}
const editButtons = this.createEditButtons(model);
if (editButtons && editButtons.length > 0) {
row = row.concat(editButtons);
}
}
return row;
}
private createSelectButton(model: ImportedModel): azdata.Component | undefined {
let selectModelButton: azdata.Component | undefined = undefined;
if (this._modelBuilder && this._settings.selectable) {
let onSelectItem = (checked: boolean) => { let onSelectItem = (checked: boolean) => {
if (!this._multiSelect) { if (!this._settings.multiSelect) {
this._selectedModel = []; this._selectedModel = [];
} }
const foundItem = this._selectedModel.find(x => x === model); const foundItem = this._selectedModel.find(x => x === model);
@@ -178,7 +200,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
} }
this.onModelSelected(); this.onModelSelected();
}; };
if (this._multiSelect) { if (this._settings.multiSelect) {
const checkbox = this._modelBuilder.checkBox().withProperties({ const checkbox = this._modelBuilder.checkBox().withProperties({
name: 'amlModel', name: 'amlModel',
value: model.id, value: model.id,
@@ -203,11 +225,53 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
}); });
selectModelButton = radioButton; selectModelButton = radioButton;
} }
return [model.artifactName, model.title, model.created, selectModelButton];
} }
return selectModelButton;
}
return []; private createEditButtons(model: ImportedModel): azdata.Component[] | undefined {
let dropButton: azdata.ButtonComponent | undefined = undefined;
let editButton: azdata.ButtonComponent | undefined = undefined;
if (this._modelBuilder && this._settings.editable) {
dropButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
width: 15,
height: 15
}).component();
dropButton.onDidClick(async () => {
try {
const confirm = await utils.promptConfirm(constants.confirmDeleteModel(model.modelName), this._apiWrapper);
if (confirm) {
await this.sendDataRequest(DeleteModelEventName, model);
if (this.parent) {
await this.parent?.refresh();
}
}
} catch (error) {
this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`);
}
});
editButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
width: 15,
height: 15
}).component();
editButton.onDidClick(async () => {
await this.sendDataRequest(EditModelEventName, model);
});
}
return editButton && dropButton ? [editButton, dropButton] : undefined;
} }
private async onModelSelected(): Promise<void> { private async onModelSelected(): Promise<void> {
@@ -221,7 +285,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
/** /**
* Returns selected data * Returns selected data
*/ */
public get data(): RegisteredModel[] | undefined { public get data(): ImportedModel[] | undefined {
return this._selectedModel; return this._selectedModel;
} }

View File

@@ -0,0 +1,75 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ModelViewBase, UpdateModelEventName } from '../modelViewBase';
import * as constants from '../../../common/constants';
import { ApiWrapper } from '../../../common/apiWrapper';
import { DialogView } from '../../dialogView';
import { ModelDetailsEditPage } from './modelDetailsEditPage';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* Dialog to render registered model views
*/
export class EditModelDialog extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string,
private _parentView: ModelViewBase | undefined,
private _model: ImportedModel) {
super(apiWrapper, root);
this.dialogView = new DialogView(this._apiWrapper);
}
public dialogView: DialogView;
public editModelPage: ModelDetailsEditPage | undefined;
/**
* Opens a dialog to edit models.
*/
public open(): void {
this.editModelPage = new ModelDetailsEditPage(this._apiWrapper, this, this._model);
let registerModelButton = this._apiWrapper.createButton(constants.extLangSaveButtonText);
registerModelButton.onClick(async () => {
if (this.editModelPage) {
const valid = await this.editModelPage.validate();
if (valid) {
try {
await this.sendDataRequest(UpdateModelEventName, this.editModelPage?.data);
this.showInfoMessage(constants.modelUpdatedSuccessfully);
if (this._parentView) {
await this._parentView.refresh();
}
} catch (error) {
this.showInfoMessage(`${constants.modelUpdateFailedError} ${constants.getErrorMessage(error)}`);
}
}
}
});
let dialog = this.dialogView.createDialog(constants.editModelTitle, [this.editModelPage]);
dialog.customButtons = [registerModelButton];
this.mainViewPanel = dialog;
dialog.okButton.hidden = true;
dialog.cancelButton.label = constants.extLangDoneButtonText;
dialog.registerCloseValidator(() => {
return false; // Blocks Enter key from closing dialog.
});
this._apiWrapper.openDialog(dialog);
}
/**
* Resets the tabs for given provider Id
*/
public async refresh(): Promise<void> {
if (this.dialogView) {
this.dialogView.refresh();
}
}
}

View File

@@ -14,7 +14,7 @@ import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage'; import { ModelSourcePage } from '../modelSourcePage';
import { ModelDetailsPage } from '../modelDetailsPage'; import { ModelDetailsPage } from '../modelDetailsPage';
import { ModelBrowsePage } from '../modelBrowsePage'; import { ModelBrowsePage } from '../modelBrowsePage';
import { ModelImportLocationPage } from './modelmportLocationPage'; import { ModelImportLocationPage } from './modelImportLocationPage';
/** /**
* Wizard to register a model * Wizard to register a model

View File

@@ -29,7 +29,10 @@ export class ManageModelsDialog extends ModelViewBase {
*/ */
public open(): void { public open(): void {
this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this); this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this, {
editable: true,
selectable: false
});
let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle); let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle);
registerModelButton.onClick(async () => { registerModelButton.onClick(async () => {

View File

@@ -0,0 +1,154 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* View to render filters to pick an azure resource
*/
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<ImportedModel> {
private _form: azdata.FormContainer | undefined;
private _nameComponent: azdata.InputBoxComponent | undefined;
private _descriptionComponent: azdata.InputBoxComponent | undefined;
private _createdComponent: azdata.Component | undefined;
private _deployedComponent: azdata.Component | undefined;
private _frameworkComponent: azdata.Component | undefined;
private _frameworkVersionComponent: azdata.Component | undefined;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._createdComponent = modelBuilder.text().withProperties({
value: this._model.created
}).component();
this._deployedComponent = modelBuilder.text().withProperties({
value: this._model.deploymentTime
}).component();
this._frameworkComponent = modelBuilder.text().withProperties({
value: this._model.framework
}).component();
this._frameworkVersionComponent = modelBuilder.text().withProperties({
value: this._model.frameworkVersion
}).component();
this._nameComponent = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength,
value: this._model.modelName
}).component();
this._descriptionComponent = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength,
value: this._model.description,
multiline: true,
height: 50
}).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._nameComponent
},
{
title: '',
component: this._descriptionComponent
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
formBuilder.addFormItems([{
title: constants.modelName,
component: this._nameComponent
}, {
title: constants.modelCreated,
component: this._createdComponent
},
{
title: constants.modelDeployed,
component: this._deployedComponent
}, {
title: constants.modelFramework,
component: this._frameworkComponent
}, {
title: constants.modelFrameworkVersion,
component: this._frameworkVersionComponent
}, {
title: constants.modelDescription,
component: this._descriptionComponent
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._createdComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._frameworkComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._frameworkVersionComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._deployedComponent
});
formBuilder.removeFormItem({
title: constants.modelName,
component: this._nameComponent
});
formBuilder.removeFormItem({
title: constants.modelDescription,
component: this._descriptionComponent
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Returns selected data
*/
public get data(): ImportedModel | undefined {
let model = Object.assign({}, this._model);
model.modelName = this._nameComponent?.value || '';
model.description = this._descriptionComponent?.value || '';
return model;
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,85 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { ModelDetailsComponent } from './modelDetailsComponent';
/**
* View to pick model source
*/
export class ModelDetailsEditPage extends ModelViewBase implements IPageView, IDataComponent<ImportedModel> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelDetailsComponent: ModelDetailsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelDetailsComponent = new ModelDetailsComponent(this._apiWrapper, this, this._model);
this.modelDetailsComponent.registerComponent(modelBuilder);
this.modelDetailsComponent.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ImportedModel | undefined {
return this.modelDetailsComponent?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this.modelDetailsComponent) {
await this.modelDetailsComponent.refresh();
}
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelImportTargetPageTitle;
}
public async disposePage(): Promise<void> {
}
public async validate(): Promise<boolean> {
let validated = false;
if (this.data?.modelName) {
validated = true;
} else {
this.showErrorMessage(constants.modelNameRequiredError);
}
return validated;
}
}

View File

@@ -44,7 +44,6 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView,
private async onTableSelected(): Promise<void> { private async onTableSelected(): Promise<void> {
if (this.tableSelectionComponent?.data) { if (this.tableSelectionComponent?.data) {
this.importTable = this.tableSelectionComponent?.data; this.importTable = this.tableSelectionComponent?.data;
//this.sendRequest(StoreImportTableEventName, this.importTable);
} }
} }

View File

@@ -19,7 +19,7 @@ import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> { export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
private _form: azdata.FormContainer | undefined; private _form: azdata.FormContainer | undefined;
private _title: string = constants.modelSourcePageTitle; private _title: string = constants.localModelPageTitle;
private _formBuilder: azdata.FormBuilder | undefined; private _formBuilder: azdata.FormBuilder | undefined;
public localModelsComponent: LocalModelsComponent | undefined; public localModelsComponent: LocalModelsComponent | undefined;
public azureModelsComponent: AzureModelsComponent | undefined; public azureModelsComponent: AzureModelsComponent | undefined;
@@ -40,7 +40,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
this.localModelsComponent.registerComponent(modelBuilder); this.localModelsComponent.registerComponent(modelBuilder);
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect); this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect);
this.azureModelsComponent.registerComponent(modelBuilder); this.azureModelsComponent.registerComponent(modelBuilder);
this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, this._multiSelect); this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, {
selectable: true,
multiSelect: this._multiSelect,
editable: false
});
this.registeredModelsComponent.registerComponent(modelBuilder); this.registeredModelsComponent.registerComponent(modelBuilder);
this.refresh(); this.refresh();
this._form = this._formBuilder.component(); this._form = this._formBuilder.component();
@@ -96,12 +100,12 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
private loadTitle(): void { private loadTitle(): void {
if (this.modelSourceType === ModelSourceType.Local) { if (this.modelSourceType === ModelSourceType.Local) {
this._title = 'Upload model file'; this._title = constants.localModelPageTitle;
} else if (this.modelSourceType === ModelSourceType.Azure) { } else if (this.modelSourceType === ModelSourceType.Azure) {
this._title = 'Import from Azure Machine Learning'; this._title = constants.azureModelPageTitle;
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) { } else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
this._title = 'Select imported model'; this._title = constants.importedModelsPageTitle;
} else { } else {
this._title = constants.modelSourcePageTitle; this._title = constants.modelSourcePageTitle;
} }
@@ -111,6 +115,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
* Returns page title * Returns page title
*/ */
public get title(): string { public get title(): string {
this.loadTitle();
return this._title; return this._title;
} }
@@ -144,7 +149,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
return { return {
modelData: x, modelData: x,
modelDetails: { modelDetails: {
title: fileName, modelName: fileName,
fileName: fileName fileName: fileName
}, },
targetImportTable: this.importTable targetImportTable: this.importTable
@@ -164,8 +169,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
model: x.model model: x.model
}, },
modelDetails: { modelDetails: {
title: x.model?.name || '', modelName: x.model?.name || '',
fileName: x.model?.name fileName: x.model?.name,
framework: x.model?.framework,
frameworkVersion: x.model?.frameworkVersion,
created: x.model?.createdTime
}, },
targetImportTable: this.importTable targetImportTable: this.importTable
}; };
@@ -178,7 +186,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
return { return {
modelData: x, modelData: x,
modelDetails: { modelDetails: {
title: '' modelName: ''
}, },
targetImportTable: this.importTable targetImportTable: this.importTable
}; };

View File

@@ -8,7 +8,7 @@ import { ModelViewBase, ModelViewData } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper'; import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants'; import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces'; import { IPageView, IDataComponent } from '../interfaces';
import { ModelDetailsComponent } from './modelDetailsComponent'; import { ModelsDetailsTableComponent } from './modelsDetailsTableComponent';
/** /**
* View to pick model details * View to pick model details
@@ -17,7 +17,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
private _form: azdata.FormContainer | undefined; private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined; private _formBuilder: azdata.FormBuilder | undefined;
public modelDetails: ModelDetailsComponent | undefined; public modelDetails: ModelsDetailsTableComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent); super(apiWrapper, parent.root, parent);
@@ -30,7 +30,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer(); this._formBuilder = modelBuilder.formContainer();
this.modelDetails = new ModelDetailsComponent(this._apiWrapper, modelBuilder, this); this.modelDetails = new ModelsDetailsTableComponent(this._apiWrapper, modelBuilder, this);
this.modelDetails.registerComponent(modelBuilder); this.modelDetails.registerComponent(modelBuilder);
this.modelDetails.addComponents(this._formBuilder); this.modelDetails.addComponents(this._formBuilder);
this.refresh(); this.refresh();
@@ -73,7 +73,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
} }
public validate(): Promise<boolean> { public validate(): Promise<boolean> {
if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.title)) { if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.modelName)) {
return Promise.resolve(true); return Promise.resolve(true);
} else { } else {
this.showErrorMessage(constants.modelNameRequiredError); this.showErrorMessage(constants.modelNameRequiredError);

View File

@@ -9,7 +9,7 @@ import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper'; import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService'; import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { RegisteredModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces'; import { ImportedModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { DeployedModelService } from '../../modelManagement/deployedModelService'; import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { ManageModelsDialog } from './manageModels/manageModelsDialog'; import { ManageModelsDialog } from './manageModels/manageModelsDialog';
@@ -17,7 +17,7 @@ import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName
} from './modelViewBase'; } from './modelViewBase';
import { ControllerBase } from '../controllerBase'; import { ControllerBase } from '../controllerBase';
import { ImportModelWizard } from './manageModels/importModelWizard'; import { ImportModelWizard } from './manageModels/importModelWizard';
@@ -26,6 +26,7 @@ import * as constants from '../../common/constants';
import { PredictWizard } from './prediction/predictWizard'; import { PredictWizard } from './prediction/predictWizard';
import { AzureModelResource } from '../interfaces'; import { AzureModelResource } from '../interfaces';
import { PredictService } from '../../prediction/predictService'; import { PredictService } from '../../prediction/predictService';
import { EditModelDialog } from './manageModels/editModelDialog';
/** /**
* Model management UI controller * Model management UI controller
@@ -71,6 +72,24 @@ export class ModelManagementController extends ControllerBase {
return view; return view;
} }
/**
* Opens the dialog to edit model
*/
public async editModel(model: ImportedModel, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
controller = controller || this;
apiWrapper = apiWrapper || this._apiWrapper;
root = root || this._root;
let view = new EditModelDialog(apiWrapper, root, parent, model);
controller.registerEvents(view);
// Open view
//
await view.open();
await view.refresh();
return view;
}
/** /**
* Opens the wizard for prediction * Opens the wizard for prediction
*/ */
@@ -136,6 +155,18 @@ export class ModelManagementController extends ControllerBase {
const importTable = <DatabaseTable>args; const importTable = <DatabaseTable>args;
await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root); await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root);
}); });
view.on(EditModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, EditModelEventName, this.editModel, model, view, this, this._apiWrapper, this._root);
});
view.on(UpdateModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, UpdateModelEventName, this.updateModel, this._registeredModelService, model);
});
view.on(DeleteModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, DeleteModelEventName, this.deleteModel, this._registeredModelService, model);
});
view.on(RegisterAzureModelEventName, async (arg) => { view.on(RegisterAzureModelEventName, async (arg) => {
let models = <ModelViewData[]>arg; let models = <ModelViewData[]>arg;
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService, await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
@@ -164,7 +195,7 @@ export class ModelManagementController extends ControllerBase {
predictArgs, predictArgs.model, predictArgs.filePath); predictArgs, predictArgs.model, predictArgs.filePath);
}); });
view.on(DownloadRegisteredModelEventName, async (arg) => { view.on(DownloadRegisteredModelEventName, async (arg) => {
let model = <RegisteredModel>arg; let model = <ImportedModel>arg;
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService, await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
model); model);
}); });
@@ -178,9 +209,13 @@ export class ModelManagementController extends ControllerBase {
await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService, await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService,
importTable); importTable);
}); });
view.on(SourceModelSelectedEventName, (arg) => { view.on(SourceModelSelectedEventName, async (arg) => {
view.modelSourceType = <ModelSourceType>arg; view.modelSourceType = <ModelSourceType>arg;
view.refresh(); await view.refresh();
});
view.on(SignInToAzureEventName, async () => {
await this.executeAction(view, SignInToAzureEventName, this.signInToAzure, this._amlService);
await view.refresh();
}); });
} }
@@ -206,6 +241,10 @@ export class ModelManagementController extends ControllerBase {
return view; return view;
} }
private async signInToAzure(service: AzureModelRegistryService): Promise<void> {
return await service.signInToAzure();
}
private async getAzureAccounts(service: AzureModelRegistryService): Promise<azdata.Account[]> { private async getAzureAccounts(service: AzureModelRegistryService): Promise<azdata.Account[]> {
return await service.getAccounts(); return await service.getAccounts();
} }
@@ -225,7 +264,7 @@ export class ModelManagementController extends ControllerBase {
return await service.getWorkspaces(account, subscription, group); return await service.getWorkspaces(account, subscription, group);
} }
private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise<RegisteredModel[]> { private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise<ImportedModel[]> {
return registeredModelService.getDeployedModels(table); return registeredModelService.getDeployedModels(table);
} }
@@ -258,6 +297,22 @@ export class ModelManagementController extends ControllerBase {
} }
} }
private async updateModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
if (model) {
await service.updateModel(model);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async deleteModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
if (model) {
await service.deleteModel(model);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async registerAzureModel( private async registerAzureModel(
azureService: AzureModelRegistryService, azureService: AzureModelRegistryService,
service: DeployedModelService, service: DeployedModelService,
@@ -306,7 +361,7 @@ export class ModelManagementController extends ControllerBase {
private async generatePredictScript( private async generatePredictScript(
predictService: PredictService, predictService: PredictService,
predictParams: PredictParameters, predictParams: PredictParameters,
registeredModel: RegisteredModel | undefined, registeredModel: ImportedModel | undefined,
filePath: string | undefined filePath: string | undefined
): Promise<string> { ): Promise<string> {
if (!predictParams) { if (!predictParams) {
@@ -334,7 +389,7 @@ export class ModelManagementController extends ControllerBase {
private async downloadRegisteredModel( private async downloadRegisteredModel(
registeredModelService: DeployedModelService, registeredModelService: DeployedModelService,
model: RegisteredModel | undefined): Promise<string> { model: ImportedModel | undefined): Promise<string> {
if (!model) { if (!model) {
throw Error(constants.invalidModelToPredictError); throw Error(constants.invalidModelToPredictError);
} }

View File

@@ -8,7 +8,7 @@ import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource'; import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper'; import { ApiWrapper } from '../../common/apiWrapper';
import { ViewBase } from '../viewBase'; import { ViewBase } from '../viewBase';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces'; import { ImportedModel, WorkspaceModel, ImportedModelDetails, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces'; import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
@@ -18,11 +18,11 @@ export interface AzureResourceEventArgs extends AzureWorkspaceResource {
} }
export interface RegisterModelEventArgs extends AzureWorkspaceResource { export interface RegisterModelEventArgs extends AzureWorkspaceResource {
details?: RegisteredModelDetails details?: ImportedModelDetails
} }
export interface PredictModelEventArgs extends PredictParameters { export interface PredictModelEventArgs extends PredictParameters {
model?: RegisteredModel; model?: ImportedModel;
filePath?: string; filePath?: string;
} }
@@ -35,8 +35,8 @@ export enum ModelSourceType {
export interface ModelViewData { export interface ModelViewData {
modelFile?: string; modelFile?: string;
modelData: AzureModelResource | string | RegisteredModel; modelData: AzureModelResource | string | ImportedModel;
modelDetails?: RegisteredModelDetails; modelDetails?: ImportedModelDetails;
targetImportTable?: DatabaseTable; targetImportTable?: DatabaseTable;
} }
@@ -57,10 +57,14 @@ export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel'; export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
export const PredictModelEventName = 'predictModel'; export const PredictModelEventName = 'predictModel';
export const RegisterModelEventName = 'registerModel'; export const RegisterModelEventName = 'registerModel';
export const EditModelEventName = 'editModel';
export const UpdateModelEventName = 'updateModel';
export const DeleteModelEventName = 'deleteModel';
export const SourceModelSelectedEventName = 'sourceModelSelected'; export const SourceModelSelectedEventName = 'sourceModelSelected';
export const LoadModelParametersEventName = 'loadModelParameters'; export const LoadModelParametersEventName = 'loadModelParameters';
export const StoreImportTableEventName = 'storeImportTable'; export const StoreImportTableEventName = 'storeImportTable';
export const VerifyImportTableEventName = 'verifyImportTable'; export const VerifyImportTableEventName = 'verifyImportTable';
export const SignInToAzureEventName = 'signInToAzure';
/** /**
* Base class for all model management views * Base class for all model management views
@@ -94,7 +98,11 @@ export abstract class ModelViewBase extends ViewBase {
DownloadRegisteredModelEventName, DownloadRegisteredModelEventName,
LoadModelParametersEventName, LoadModelParametersEventName,
StoreImportTableEventName, StoreImportTableEventName,
VerifyImportTableEventName]); VerifyImportTableEventName,
EditModelEventName,
UpdateModelEventName,
DeleteModelEventName,
SignInToAzureEventName]);
} }
/** /**
@@ -115,7 +123,7 @@ export abstract class ModelViewBase extends ViewBase {
/** /**
* list registered models * list registered models
*/ */
public async listModels(table: DatabaseTable): Promise<RegisteredModel[]> { public async listModels(table: DatabaseTable): Promise<ImportedModel[]> {
return await this.sendDataRequest(ListModelsEventName, table); return await this.sendDataRequest(ListModelsEventName, table);
} }
@@ -170,7 +178,7 @@ export abstract class ModelViewBase extends ViewBase {
* downloads registered model * downloads registered model
* @param model model to download * @param model model to download
*/ */
public async downloadRegisteredModel(model: RegisteredModel | undefined): Promise<string> { public async downloadRegisteredModel(model: ImportedModel | undefined): Promise<string> {
return await this.sendDataRequest(DownloadRegisteredModelEventName, model); return await this.sendDataRequest(DownloadRegisteredModelEventName, model);
} }
@@ -215,7 +223,7 @@ export abstract class ModelViewBase extends ViewBase {
* registers azure model * registers azure model
* @param args azure resource * @param args azure resource
*/ */
public async generatePredictScript(model: RegisteredModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> { public async generatePredictScript(model: ImportedModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> {
const args: PredictModelEventArgs = Object.assign({}, params, { const args: PredictModelEventArgs = Object.assign({}, params, {
model: model, model: model,
filePath: filePath, filePath: filePath,

View File

@@ -12,7 +12,7 @@ import { IDataComponent } from '../interfaces';
/** /**
* View to pick local models file * View to pick local models file
*/ */
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<ModelViewData[]> { export class ModelsDetailsTableComponent extends ModelViewBase implements IDataComponent<ModelViewData[]> {
private _table: azdata.DeclarativeTableComponent | undefined; private _table: azdata.DeclarativeTableComponent | undefined;
/** /**
@@ -127,7 +127,7 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone
private createTableRow(model: ModelViewData | undefined): any[] { private createTableRow(model: ModelViewData | undefined): any[] {
if (this._modelBuilder && model && model.modelDetails) { if (this._modelBuilder && model && model.modelDetails) {
const nameComponent = this._modelBuilder.inputBox().withProperties({ const nameComponent = this._modelBuilder.inputBox().withProperties({
value: model.modelDetails.title, value: model.modelDetails.modelName,
width: this.componentMaxLength - 100, width: this.componentMaxLength - 100,
required: true required: true
}).component(); }).component();
@@ -142,7 +142,7 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone
}); });
nameComponent.onTextChanged(() => { nameComponent.onTextChanged(() => {
if (model.modelDetails) { if (model.modelDetails) {
model.modelDetails.title = nameComponent.value || ''; model.modelDetails.modelName = nameComponent.value || '';
} }
}); });
let deleteButton = this._modelBuilder.button().withProperties({ let deleteButton = this._modelBuilder.button().withProperties({

View File

@@ -13,7 +13,7 @@ import * as constants from '../../../common/constants';
import { WizardView } from '../../wizardView'; import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage'; import { ModelSourcePage } from '../modelSourcePage';
import { ColumnsSelectionPage } from './columnsSelectionPage'; import { ColumnsSelectionPage } from './columnsSelectionPage';
import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ImportedModel } from '../../../modelManagement/interfaces';
import { ModelArtifact } from './modelArtifact'; import { ModelArtifact } from './modelArtifact';
import { ModelBrowsePage } from '../modelBrowsePage'; import { ModelBrowsePage } from '../modelBrowsePage';
@@ -124,7 +124,7 @@ export class PredictWizard extends ModelViewBase {
private async predict(): Promise<boolean> { private async predict(): Promise<boolean> {
try { try {
let modelFilePath: string | undefined; let modelFilePath: string | undefined;
let registeredModel: RegisteredModel | undefined = undefined; let registeredModel: ImportedModel | undefined = undefined;
if (this.modelResources && this.modelResources.data && this.modelResources.data === ModelSourceType.RegisteredModels if (this.modelResources && this.modelResources.data && this.modelResources.data === ModelSourceType.RegisteredModels
&& this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) { && this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) {
const data = this.modelBrowsePage?.registeredModelsComponent?.data; const data = this.modelBrowsePage?.registeredModelsComponent?.data;

View File

@@ -168,7 +168,7 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo
this._selectedTableName = this.getTableFullName(selectedTable); this._selectedTableName = this.getTableFullName(selectedTable);
this._tables.value = this.getTableFullName(selectedTable); this._tables.value = this.getTableFullName(selectedTable);
} else { } else {
this._selectedTableName = this.getTableFullName(this._tableNames[0]); this._selectedTableName = this._editable ? this.getTableFullName(this.importTable) : this.getTableFullName(this._tableNames[0]);
} }
} else { } else {
this._selectedTableName = this.getTableFullName(this._tableNames[0]); this._selectedTableName = this.getTableFullName(this._tableNames[0]);

View File

@@ -35,13 +35,14 @@ export class WizardView extends MainViewBase {
*/ */
public addWizardPage(page: IPageView, index: number): void { public addWizardPage(page: IPageView, index: number): void {
if (this._wizard) { if (this._wizard) {
this.addPage(page, index); const currentPage = this._wizard.currentPage;
this._wizard.removePage(index); if (page && currentPage < index) {
if (!page.viewPanel) { this.addPage(page, index);
this._wizard.removePage(index);
this.createWizardPage(page.title || '', page); this.createWizardPage(page.title || '', page);
this._wizard.addPage(<azdata.window.WizardPage>page.viewPanel, index);
this._wizard.setCurrentPage(currentPage);
} }
this._wizard.addPage(<azdata.window.WizardPage>page.viewPanel, index);
this._wizard.setCurrentPage(index);
} }
} }
@@ -109,4 +110,14 @@ export class WizardView extends MainViewBase {
public get wizard(): azdata.window.Wizard | undefined { public get wizard(): azdata.window.Wizard | undefined {
return this._wizard; return this._wizard;
} }
public async refresh(): Promise<void> {
for (let index = 0; index < this._pages.length; index++) {
const page = this._pages[index];
if (this._wizard?.pages[index]?.title !== page.title) {
this.addWizardPage(page, index);
}
}
await super.refresh();
}
} }