diff --git a/extensions/machine-learning-services/src/common/constants.ts b/extensions/machine-learning-services/src/common/constants.ts index e43283c6bc..3d777fd8d8 100644 --- a/extensions/machine-learning-services/src/common/constants.ts +++ b/extensions/machine-learning-services/src/common/constants.ts @@ -20,6 +20,7 @@ export const extensionOutputChannel = 'SQL Machine Learning'; export const notebookExtensionName = 'Microsoft.notebook'; export const azureSubscriptionsCommand = 'azure.accounts.getSubscriptions'; export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups'; +export const signInToAzureCommand = 'azure.resource.signin'; // Tasks, commands // @@ -57,6 +58,10 @@ export function confirmInstallPythonPackages(packages: string): string { return localize('mls.installDependencies.confirmInstallPythonPackages' , "The following Python packages are required to install: {0}. Are you sure you want to install?", packages); } +export function confirmDeleteModel(modelName: string): string { + return localize('models.confirmDeleteModel' + , "Are you sure you want to delete model '{0}?", modelName); +} export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ..."); export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed."); export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); } @@ -114,16 +119,20 @@ export const extLangSelectedPath = localize('extLang.selectedPath', "Selected Pa export const extLangInstallFailedError = localize('extLang.installFailedError', "Failed to install language"); export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language"); -export const modelArtifactName = localize('models.artifactName', "Artifact Name"); +export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model"); export const databaseName = localize('databaseName', "Database name"); export const tableName = localize('tableName', "Table name"); export const modelName = localize('models.name', "Name"); export const modelFileName = localize('models.fileName', "File"); export const modelDescription = localize('models.description', "Description"); -export const modelCreated = localize('models.created', "Date Created"); +export const modelCreated = localize('models.created', "Date created"); +export const modelDeployed = localize('models.deployed', "Date deployed"); +export const modelFramework = localize('models.framework', "Framework"); +export const modelFrameworkVersion = localize('models.frameworkVersion', "Framework version"); export const modelVersion = localize('models.version', "Version"); export const browseModels = localize('models.browseButton', "..."); export const azureAccount = localize('models.azureAccount', "Azure account"); +export const azureSignIn = localize('models.azureSignIn', "Sign in to Azure"); export const columnDatabase = localize('predict.columnDatabase', "Target database"); export const columnTable = localize('predict.columnTable', "Target table"); export const inputColumns = localize('predict.inputColumns', "Model input mapping"); @@ -151,15 +160,20 @@ export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy" export const predictModel = localize('models.predictModel', "Predict"); export const registerModelTitle = localize('models.RegisterWizard', "Import models"); export const importModelTitle = localize('models.importModelTitle', "Import models"); +export const editModelTitle = localize('models.editModelTitle', "Edit model"); export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model"); export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions"); export const makePredictionDesc = localize('models.makePredictionDesc', "Generates a predicted value or scores using a managed model"); export const createNotebookTitle = localize('models.createNotebookTitle', "Create notebook"); export const createNotebookDesc = localize('models.createNotebookDesc', "Run experiments and create models"); export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully"); +export const modelUpdatedSuccessfully = localize('models.modelUpdatedSuccessfully', "Model updated successfully"); export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register"); export const localModelSource = localize('models.localModelSource', "File upload"); +export const localModelPageTitle = localize('models.localModelPageTitle', "Upload model file"); export const azureModelSource = localize('models.azureModelSource', "Azure Machine Learning"); +export const azureModelPageTitle = localize('models.azureModelPageTitle', "Import from Azure Machine Learning"); +export const importedModelsPageTitle = localize('models.importedModelsPageTitle', "Select imported model"); export const registeredModelsSource = localize('models.registeredModelsSource', "Imported models"); export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure"); export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource"); diff --git a/extensions/machine-learning-services/src/common/utils.ts b/extensions/machine-learning-services/src/common/utils.ts index 892aad9f30..f4113c715d 100644 --- a/extensions/machine-learning-services/src/common/utils.ts +++ b/extensions/machine-learning-services/src/common/utils.ts @@ -8,7 +8,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid'; import * as path from 'path'; import * as os from 'os'; import * as fs from 'fs'; -import * as constants from '../common/constants'; +import * as constants from './constants'; import { promisify } from 'util'; import { ApiWrapper } from './apiWrapper'; diff --git a/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts b/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts index 47f24c0532..96128007cf 100644 --- a/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts +++ b/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts @@ -134,6 +134,10 @@ export class AzureModelRegistryService { this._modelClient = value; } + public async signInToAzure(): Promise { + await this._apiWrapper.executeCommand(constants.signInToAzureCommand); + } + /** * Execute the background task to download the artifact */ diff --git a/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts index 13c3542f9a..c70e3d41dd 100644 --- a/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts +++ b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts @@ -9,9 +9,10 @@ import { ApiWrapper } from '../common/apiWrapper'; import * as utils from '../common/utils'; import { Config } from '../configurations/config'; import { QueryRunner } from '../common/queryRunner'; -import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces'; +import { ImportedModel, ImportedModelDetails, ModelParameters } from './interfaces'; import { ModelPythonClient } from './modelPythonClient'; import * as constants from '../common/constants'; +import * as queries from './queries'; import { DatabaseTable } from '../prediction/interfaces'; import { ModelConfigRecent } from './modelConfigRecent'; @@ -34,14 +35,14 @@ export class DeployedModelService { /** * Returns deployed models */ - public async getDeployedModels(table: DatabaseTable): Promise { + public async getDeployedModels(table: DatabaseTable): Promise { let connection = await this.getCurrentConnection(); - let list: RegisteredModel[] = []; + let list: ImportedModel[] = []; if (!table.databaseName || !table.tableName || !table.schema) { return []; } if (connection) { - const query = this.getDeployedModelsQuery(table); + const query = queries.getDeployedModelsQuery(table); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { result.rows.forEach(row => { @@ -58,10 +59,10 @@ export class DeployedModelService { * Downloads model * @param model model object */ - public async downloadModel(model: RegisteredModel): Promise { + public async downloadModel(model: ImportedModel): Promise { let connection = await this.getCurrentConnection(); if (connection) { - const query = this.getModelContentQuery(model); + const query = queries.getModelContentQuery(model); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { const content = result.rows[0][0].displayValue; @@ -86,29 +87,23 @@ export class DeployedModelService { * @param filePath model file path * @param details model details */ - public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined, table: DatabaseTable) { + public async deployLocalModel(filePath: string, details: ImportedModelDetails | undefined, table: DatabaseTable) { let connection = await this.getCurrentConnection(); if (connection && table.databaseName) { await this.configureImport(connection, table); - let currentModels = await this.getDeployedModels(table); const content = await utils.readFileInHex(filePath); - const fileName = details?.fileName || utils.getFileName(filePath); - let modelToAdd: RegisteredModel = { + let modelToAdd: ImportedModel = Object.assign({}, { id: 0, - artifactName: fileName, content: content, - title: details?.title || fileName, - description: details?.description, - version: details?.version, table: table - }; - await this._queryRunner.runWithDatabaseChange(connection, this.getInsertModelQuery(modelToAdd, table), table.databaseName); + }, details); + await this._queryRunner.runWithDatabaseChange(connection, queries.getInsertModelQuery(modelToAdd, table), table.databaseName); let updatedModels = await this.getDeployedModels(table); if (updatedModels.length < currentModels.length + 1) { - throw Error(constants.importModelFailedError(details?.title, filePath)); + throw Error(constants.importModelFailedError(details?.modelName, filePath)); } } else { @@ -116,12 +111,36 @@ export class DeployedModelService { } } + /** + * Updates a model + */ + public async updateModel(model: ImportedModel) { + let connection = await this.getCurrentConnection(); + if (connection && model && model.table && model.table.databaseName) { + await this._queryRunner.runWithDatabaseChange(connection, queries.getUpdateModelQuery(model), model.table.databaseName); + } else { + throw new Error(constants.noConnectionError); + } + } + + /** + * Updates a model + */ + public async deleteModel(model: ImportedModel) { + let connection = await this.getCurrentConnection(); + if (connection && model && model.table && model.table.databaseName) { + await this._queryRunner.runWithDatabaseChange(connection, queries.getDeleteModelQuery(model), model.table.databaseName); + } else { + throw new Error(constants.noConnectionError); + } + } + public async configureImport(connection: azdata.connection.ConnectionProfile, table: DatabaseTable) { if (connection && table.databaseName) { - let query = this.getDatabaseConfigureQuery(table); + let query = queries.getDatabaseConfigureQuery(table); await this._queryRunner.safeRunQuery(connection, query); - query = this.getConfigureTableQuery(table); + query = queries.getConfigureTableQuery(table); await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName); } } @@ -140,7 +159,7 @@ export class DeployedModelService { // If database exist verify the table schema // if ((await databases).find(x => x === table.databaseName)) { - const query = this.getConfigTableVerificationQuery(table); + const query = queries.getConfigTableVerificationQuery(table); const result = await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName); return result !== undefined && result.rows.length > 0 && result.rows[0][0].displayValue === '1'; } else { @@ -178,14 +197,18 @@ export class DeployedModelService { } } - private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): RegisteredModel { + private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): ImportedModel { return { id: +row[0].displayValue, - artifactName: row[1].displayValue, - title: row[2].displayValue, - description: row[3].displayValue, - version: row[4].displayValue, - created: row[5].displayValue, + modelName: row[1].displayValue, + description: row[2].displayValue, + version: row[3].displayValue, + created: row[4].displayValue, + framework: row[5].displayValue, + frameworkVersion: row[6].displayValue, + deploymentTime: row[7].displayValue, + deployedBy: row[8].displayValue, + runId: row[9].displayValue, table: table }; } @@ -193,160 +216,4 @@ export class DeployedModelService { private async getCurrentConnection(): Promise { return await this._apiWrapper.getCurrentConnection(); } - - public getDatabaseConfigureQuery(configTable: DatabaseTable): string { - return ` - IF NOT EXISTS ( - SELECT name - FROM sys.databases - WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}' - ) - CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}] - `; - } - - public getDeployedModelsQuery(table: DatabaseTable): string { - return ` - SELECT artifact_id, artifact_name, name, description, version, created - FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')} - WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml' - Order by artifact_id - `; - } - - /** - * Verifies config table has the expected schema - * @param databaseName - * @param tableName - */ - public getConfigTableVerificationQuery(table: DatabaseTable): string { - let tableName = table.tableName; - let schemaName = table.schema; - const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); - - return ` - IF NOT EXISTS ( - SELECT name - FROM sys.databases - WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}' - ) - BEGIN - Select 1 - END - ELSE - BEGIN - USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}] - IF EXISTS - ( SELECT t.name, s.name - FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id - WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}' - AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}' - ) - BEGIN - IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name') - AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content') - AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name') - AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version') - AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created') - BEGIN - Select 1 - END - ELSE - BEGIN - Select 0 - END - END - ELSE - select 1 - END - `; - } - - /** - * Update the table and adds extra columns (name, description, version) if doesn't already exist. - * Note: this code is temporary and will be removed weh the table supports the required schema - * @param databaseName - * @param tableName - */ - public getConfigureTableQuery(table: DatabaseTable): string { - let tableName = table.tableName; - let schemaName = table.schema; - const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); - - return ` - IF EXISTS - ( SELECT t.name, s.name - FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id - WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}' - AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}' - ) - BEGIN - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name') - ALTER TABLE ${twoPartTableName} ADD [artifact_name] [varchar](256) NOT NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content') - ALTER TABLE ${twoPartTableName} ADD [artifact_content] [varbinary](max) NOT NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name') - ALTER TABLE ${twoPartTableName} ADD [name] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version') - ALTER TABLE ${twoPartTableName} ADD [version] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created') - BEGIN - ALTER TABLE ${twoPartTableName} ADD [created] [datetime] NULL - ALTER TABLE ${twoPartTableName} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created - END - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='description') - ALTER TABLE ${twoPartTableName} ADD [description] [varchar](256) NULL - END - Else - BEGIN - CREATE TABLE ${twoPartTableName}( - [artifact_id] [int] IDENTITY(1,1) NOT NULL, - [artifact_name] [varchar](256) NOT NULL, - [artifact_content] [varbinary](max) NOT NULL, - [artifact_initial_size] [bigint] NULL, - [name] [varchar](256) NULL, - [version] [varchar](256) NULL, - [created] [datetime] NULL, - [description] [varchar](256) NULL, - CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_artifact_pk] PRIMARY KEY CLUSTERED - ( - [artifact_id] ASC - )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] - ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] - ALTER TABLE [dbo].[${utils.doubleEscapeSingleBrackets(tableName)}] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] - END - `; - } - - public getInsertModelQuery(model: RegisteredModel, table: DatabaseTable): string { - const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); - const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || ''); - let updateScript = ` - Insert into ${twoPartTableName} - (artifact_name, artifact_content, name, version, description) - values ( - '${utils.doubleEscapeSingleQuotes(model.artifactName || '')}', - ${utils.doubleEscapeSingleQuotes(model.content || '')}, - '${utils.doubleEscapeSingleQuotes(model.title || '')}', - '${utils.doubleEscapeSingleQuotes(model.version || '')}', - '${utils.doubleEscapeSingleQuotes(model.description || '')}') - `; - - return ` - ${updateScript} - - SELECT artifact_id, artifact_name, name, description, version, created - FROM ${threePartTableName} - WHERE artifact_id = SCOPE_IDENTITY(); - `; - } - - public getModelContentQuery(model: RegisteredModel): string { - const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || ''); - return ` - SELECT artifact_content - FROM ${threePartTableName} - WHERE artifact_id = ${model.id}; - `; - } } diff --git a/extensions/machine-learning-services/src/modelManagement/interfaces.ts b/extensions/machine-learning-services/src/modelManagement/interfaces.ts index 03f143aa2e..8884401e33 100644 --- a/extensions/machine-learning-services/src/modelManagement/interfaces.ts +++ b/extensions/machine-learning-services/src/modelManagement/interfaces.ts @@ -47,11 +47,10 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & { }; /** - * An interface representing registered model + * An interface representing imported model */ -export interface RegisteredModel extends RegisteredModelDetails { +export interface ImportedModel extends ImportedModelDetails { id: number; - artifactName: string; content?: string; table: DatabaseTable; } @@ -67,14 +66,19 @@ export interface ModelParameters { } /** - * An interface representing registered model + * An interface representing imported model */ -export interface RegisteredModelDetails { - title: string; +export interface ImportedModelDetails { + modelName: string; created?: string; + deploymentTime?: string; version?: string; description?: string; fileName?: string; + framework?: string; + frameworkVersion?: string; + runId?: string; + deployedBy?: string; } /** diff --git a/extensions/machine-learning-services/src/modelManagement/queries.ts b/extensions/machine-learning-services/src/modelManagement/queries.ts new file mode 100644 index 0000000000..da19c7228f --- /dev/null +++ b/extensions/machine-learning-services/src/modelManagement/queries.ts @@ -0,0 +1,195 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as utils from '../common/utils'; +import { DatabaseTable } from '../prediction/interfaces'; +import { ImportedModel } from './interfaces'; + +export function getDatabaseConfigureQuery(configTable: DatabaseTable): string { + return ` + IF NOT EXISTS ( + SELECT name + FROM sys.databases + WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}' + ) + CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}] + `; +} + +export function getDeployedModelsQuery(table: DatabaseTable): string { + return ` + ${selectQuery} + FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')} + WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml' + ORDER BY model_id + `; +} + +/** + * Verifies config table has the expected schema + * @param databaseName + * @param tableName + */ +export function getConfigTableVerificationQuery(table: DatabaseTable): string { + let tableName = table.tableName; + let schemaName = table.schema; + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); + + return ` + IF NOT EXISTS ( + SELECT name + FROM sys.databases + WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}' + ) + BEGIN + SELECT 1 + END + ELSE + BEGIN + USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}] + IF EXISTS + ( SELECT t.name, s.name + FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id + WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}' + AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}' + ) + BEGIN + IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_name') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_id') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_description') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework_version') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_version') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_creation_time') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_deployment_time') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='deployed_by') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='run_id') + BEGIN + SELECT 1 + END + ELSE + BEGIN + SELECT 0 + END + END + ELSE + SELECT 1 + END + `; +} + +/** + * Creates the import table if doesn't exist + */ +export function getConfigureTableQuery(table: DatabaseTable): string { + let tableName = table.tableName; + let schemaName = table.schema; + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); + + return ` + IF NOT EXISTS + ( SELECT t.name, s.name + FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id + WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}' + AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}' + ) + BEGIN + CREATE TABLE ${twoPartTableName}( + [model_id] [int] IDENTITY(1,1) NOT NULL, + [model_name] [varchar](256) NOT NULL, + [model_framework] [varchar](256) NULL, + [model_framework_version] [varchar](256) NULL, + [model] [varbinary](max) NOT NULL, + [model_version] [varchar](256) NULL, + [model_creation_time] [datetime2] NULL, + [model_deployment_time] [datetime2] NULL, + [deployed_by] [int] NULL, + [model_description] [varchar](256) NULL, + [run_id] [varchar](256) NULL, + CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_models_pk] PRIMARY KEY CLUSTERED + ( + [model_id] ASC + )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] + ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] + ALTER TABLE ${twoPartTableName} ADD CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time] + END + `; +} + +export function getInsertModelQuery(model: ImportedModel, table: DatabaseTable): string { + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); + const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || ''); + let updateScript = ` + INSERT INTO ${twoPartTableName} + (model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id) + VALUES ( + '${utils.doubleEscapeSingleQuotes(model.modelName || '')}', + ${utils.doubleEscapeSingleQuotes(model.content || '')}, + '${utils.doubleEscapeSingleQuotes(model.version || '')}', + '${utils.doubleEscapeSingleQuotes(model.description || '')}', + '${utils.doubleEscapeSingleQuotes(model.created || '')}', + '${utils.doubleEscapeSingleQuotes(model.framework || '')}', + '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}', + '${utils.doubleEscapeSingleQuotes(model.runId || '')}') + `; + + return ` + ${updateScript} + ${selectQuery} + FROM ${threePartTableName} + WHERE model_id = SCOPE_IDENTITY(); + `; +} + +export function getModelContentQuery(model: ImportedModel): string { + const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || ''); + return ` + SELECT model + FROM ${threePartTableName} + WHERE model_id = ${model.id}; + `; +} + +export function getUpdateModelQuery(model: ImportedModel): string { + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || ''); + const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || ''); + let updateScript = ` + UPDATE ${twoPartTableName} + SET + model_name = '${utils.doubleEscapeSingleQuotes(model.modelName || '')}', + model_version = '${utils.doubleEscapeSingleQuotes(model.version || '')}', + model_description = '${utils.doubleEscapeSingleQuotes(model.description || '')}', + model_creation_time = '${utils.doubleEscapeSingleQuotes(model.created || '')}', + model_framework = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}', + model_framework_version = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}', + run_id = '${utils.doubleEscapeSingleQuotes(model.runId || '')}' + WHERE model_id = ${model.id}`; + + return ` + ${updateScript} + ${selectQuery} + FROM ${threePartTableName} + WHERE model_id = ${model.id}; + `; +} + +export function getDeleteModelQuery(model: ImportedModel): string { + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || ''); + const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || ''); + let updateScript = ` + Delete from ${twoPartTableName} + WHERE model_id = ${model.id}`; + + return ` + ${updateScript} + ${selectQuery} + FROM ${threePartTableName} + `; +} + +export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id'; + + diff --git a/extensions/machine-learning-services/src/packageManagement/packageManager.ts b/extensions/machine-learning-services/src/packageManagement/packageManager.ts index ec45279b29..54b0c23ecc 100644 --- a/extensions/machine-learning-services/src/packageManagement/packageManager.ts +++ b/extensions/machine-learning-services/src/packageManagement/packageManager.ts @@ -179,8 +179,8 @@ export class PackageManager { let cmd = `"${this.pythonExecutable}" -m pip list --format=json`; let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined); let packagesResult: nbExtensionApis.IPackageDetails[] = []; - if (packagesInfo) { - packagesResult = JSON.parse(packagesInfo); + if (packagesInfo && packagesInfo.indexOf(']') > 0) { + packagesResult = JSON.parse(packagesInfo.substr(0, packagesInfo.indexOf(']') + 1)); } return packagesResult; } diff --git a/extensions/machine-learning-services/src/prediction/predictService.ts b/extensions/machine-learning-services/src/prediction/predictService.ts index d60f28a663..4bd3ddc9fe 100644 --- a/extensions/machine-learning-services/src/prediction/predictService.ts +++ b/extensions/machine-learning-services/src/prediction/predictService.ts @@ -8,7 +8,7 @@ import * as azdata from 'azdata'; import { ApiWrapper } from '../common/apiWrapper'; import { QueryRunner } from '../common/queryRunner'; import * as utils from '../common/utils'; -import { RegisteredModel } from '../modelManagement/interfaces'; +import { ImportedModel } from '../modelManagement/interfaces'; import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces'; /** @@ -42,7 +42,7 @@ export class PredictService { */ public async generatePredictScript( predictParams: PredictParameters, - registeredModel: RegisteredModel | undefined, + registeredModel: ImportedModel | undefined, filePath: string | undefined ): Promise { let connection = await this.getCurrentConnection(); @@ -146,9 +146,9 @@ WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuo const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || ''); return ` DECLARE @model VARBINARY(max) = ( - SELECT artifact_content + SELECT model FROM ${threePartTableName} - WHERE artifact_id = ${modelId} + WHERE model_id = ${modelId} ); WITH predict_input AS ( diff --git a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts index 3daf6e6850..ef0761a3c3 100644 --- a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts +++ b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts @@ -11,7 +11,7 @@ import * as should from 'should'; import { Config } from '../../configurations/config'; import { DeployedModelService } from '../../modelManagement/deployedModelService'; import { QueryRunner } from '../../common/queryRunner'; -import { RegisteredModel } from '../../modelManagement/interfaces'; +import { ImportedModel } from '../../modelManagement/interfaces'; import { ModelPythonClient } from '../../modelManagement/modelPythonClient'; import * as path from 'path'; import * as os from 'os'; @@ -19,6 +19,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid'; import * as fs from 'fs'; import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent'; import { DatabaseTable } from '../../prediction/interfaces'; +import * as queries from '../../modelManagement/queries'; interface TestContext { @@ -70,14 +71,18 @@ describe('DeployedModelService', () => { const testContext = createContext(); const connection = new azdata.connection.ConnectionProfile(); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - const expected: RegisteredModel[] = [ + const expected: ImportedModel[] = [ { id: 1, - artifactName: 'name1', - title: 'title1', + modelName: 'name1', description: 'desc1', created: '2018-01-01', + deploymentTime: '2018-01-01', version: '1.1', + framework: 'onnx', + frameworkVersion: '1', + deployedBy: '1', + runId: 'run1', table: testContext.importTable } @@ -97,11 +102,6 @@ describe('DeployedModelService', () => { isNull: false, invariantCultureDisplayValue: '' }, - { - displayValue: 'title1', - isNull: false, - invariantCultureDisplayValue: '' - }, { displayValue: 'desc1', isNull: false, @@ -116,6 +116,31 @@ describe('DeployedModelService', () => { displayValue: '2018-01-01', isNull: false, invariantCultureDisplayValue: '' + }, + { + displayValue: 'onnx', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '2018-01-01', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'run1', + isNull: false, + invariantCultureDisplayValue: '' } ] ] @@ -127,9 +152,6 @@ describe('DeployedModelService', () => { testContext.modelClient.object, testContext.recentModels.object); testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); - - testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); - testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table'); const actual = await service.getDeployedModels(testContext.importTable); should.deepEqual(actual, expected); }); @@ -171,14 +193,18 @@ describe('DeployedModelService', () => { const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`); await fs.promises.writeFile(tempFilePath, 'test'); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - const model: RegisteredModel = + const model: ImportedModel = { id: 1, - artifactName: 'name1', - title: 'title1', + modelName: 'name1', description: 'desc1', created: '2018-01-01', + deploymentTime: '2018-01-01', version: '1.1', + framework: 'onnx', + frameworkVersion: '1', + deployedBy: '1', + runId: 'run1', table: testContext.importTable }; const result = { @@ -213,47 +239,72 @@ describe('DeployedModelService', () => { const testContext = createContext(); const connection = new azdata.connection.ConnectionProfile(); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - const model: RegisteredModel = + const model: ImportedModel = { id: 1, - artifactName: 'name1', - title: 'title1', + modelName: 'name1', description: 'desc1', created: '2018-01-01', + deploymentTime: '2018-01-01', version: '1.1', + framework: 'onnx', + frameworkVersion: '1', + deployedBy: '1', + runId: 'run1', table: testContext.importTable + }; const row = [ - { - displayValue: '1', - isNull: false, - invariantCultureDisplayValue: '' - }, - { - displayValue: 'name1', - isNull: false, - invariantCultureDisplayValue: '' - }, - { - displayValue: 'title1', - isNull: false, - invariantCultureDisplayValue: '' - }, - { - displayValue: 'desc1', - isNull: false, - invariantCultureDisplayValue: '' - }, - { - displayValue: '1.1', - isNull: false, - invariantCultureDisplayValue: '' - }, - { - displayValue: '2018-01-01', - isNull: false, - invariantCultureDisplayValue: '' - } + { + displayValue: '1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'name1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'desc1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '1.1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '2018-01-01', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'onnx', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '2018-01-01', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'run1', + isNull: false, + invariantCultureDisplayValue: '' + } ]; const result = { rowCount: 1, @@ -273,7 +324,7 @@ describe('DeployedModelService', () => { testContext.modelClient.object, testContext.recentModels.object); - testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0), TypeMoq.It.isAny())).returns(() => { + testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('INSERT INTO') > 0), TypeMoq.It.isAny())).returns(() => { deployed = true; return Promise.resolve(result); }); @@ -298,145 +349,105 @@ describe('DeployedModelService', () => { it('getConfigureQuery should escape db name', async function (): Promise { const testContext = createContext(); - let service = new DeployedModelService( - testContext.apiWrapper.object, - testContext.config.object, - testContext.queryRunner.object, - testContext.modelClient.object, - testContext.recentModels.object); testContext.importTable.databaseName = 'd[]b'; testContext.importTable.tableName = 'ta[b]le'; testContext.importTable.schema = 'dbo'; const expected = ` - IF EXISTS + IF NOT EXISTS ( SELECT t.name, s.name FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id WHERE t.name = 'ta[b]le' AND s.name = 'dbo' ) - BEGIN - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_name') - ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_name] [varchar](256) NOT NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_content') - ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_content] [varbinary](max) NOT NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name') - ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version') - ALTER TABLE [dbo].[ta[[b]]le] ADD [version] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='created') - BEGIN - ALTER TABLE [dbo].[ta[[b]]le] ADD [created] [datetime] NULL - ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created - END - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='description') - ALTER TABLE [dbo].[ta[[b]]le] ADD [description] [varchar](256) NULL - END - Else BEGIN CREATE TABLE [dbo].[ta[[b]]le]( - [artifact_id] [int] IDENTITY(1,1) NOT NULL, - [artifact_name] [varchar](256) NOT NULL, - [artifact_content] [varbinary](max) NOT NULL, - [artifact_initial_size] [bigint] NULL, - [name] [varchar](256) NULL, - [version] [varchar](256) NULL, - [created] [datetime] NULL, - [description] [varchar](256) NULL, - CONSTRAINT [ta[[b]]le_artifact_pk] PRIMARY KEY CLUSTERED + [model_id] [int] IDENTITY(1,1) NOT NULL, + [model_name] [varchar](256) NOT NULL, + [model_framework] [varchar](256) NULL, + [model_framework_version] [varchar](256) NULL, + [model] [varbinary](max) NOT NULL, + [model_version] [varchar](256) NULL, + [model_creation_time] [datetime2] NULL, + [model_deployment_time] [datetime2] NULL, + [deployed_by] [int] NULL, + [model_description] [varchar](256) NULL, + [run_id] [varchar](256) NULL, + CONSTRAINT [ta[[b]]le_models_pk] PRIMARY KEY CLUSTERED ( - [artifact_id] ASC + [model_id] ASC )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] - ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] + ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [ta[[b]]le_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time] END `; - const actual = service.getConfigureTableQuery(testContext.importTable); + const actual = queries.getConfigureTableQuery(testContext.importTable); should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`); }); it('getDeployedModelsQuery should escape db name', async function (): Promise { const testContext = createContext(); - let service = new DeployedModelService( - testContext.apiWrapper.object, - testContext.config.object, - testContext.queryRunner.object, - testContext.modelClient.object, - testContext.recentModels.object); testContext.importTable.databaseName = 'd[]b'; testContext.importTable.tableName = 'ta[b]le'; testContext.importTable.schema = 'dbo'; const expected = ` - SELECT artifact_id, artifact_name, name, description, version, created + SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id FROM [d[[]]b].[dbo].[ta[[b]]le] - WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml' - Order by artifact_id + WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml' + ORDER BY model_id `; - const actual = service.getDeployedModelsQuery(testContext.importTable); + const actual = queries.getDeployedModelsQuery(testContext.importTable); should.deepEqual(expected, actual); }); it('getInsertModelQuery should escape db name', async function (): Promise { const testContext = createContext(); - const model: RegisteredModel = + const model: ImportedModel = { id: 1, - artifactName: 'name1', - title: 'title1', + modelName: 'name1', description: 'desc1', created: '2018-01-01', version: '1.1', table: testContext.importTable }; - let service = new DeployedModelService( - testContext.apiWrapper.object, - testContext.config.object, - testContext.queryRunner.object, - testContext.modelClient.object, - testContext.recentModels.object); - - const expected = ` - Insert into [dbo].[tb] - (artifact_name, artifact_content, name, version, description) - values ( + const expected = `INSERT INTO [dbo].[tb] + (model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id) + VALUES ( 'name1', , - 'title1', '1.1', - 'desc1')`; - const actual = service.getInsertModelQuery(model, testContext.importTable); - should.equal(actual.indexOf(expected) > 0, true); + 'desc1', + '2018-01-01', + '', + '', + '')`; + const actual = queries.getInsertModelQuery(model, testContext.importTable); + should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`); }); it('getModelContentQuery should escape db name', async function (): Promise { const testContext = createContext(); - const model: RegisteredModel = + const model: ImportedModel = { id: 1, - artifactName: 'name1', - title: 'title1', + modelName: 'name1', description: 'desc1', created: '2018-01-01', version: '1.1', table: testContext.importTable }; - let service = new DeployedModelService( - testContext.apiWrapper.object, - testContext.config.object, - testContext.queryRunner.object, - testContext.modelClient.object, - testContext.recentModels.object); model.table = { databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo' }; const expected = ` - SELECT artifact_content + SELECT model FROM [d[[]]b].[dbo].[ta[[b]]le] - WHERE artifact_id = 1; + WHERE model_id = 1; `; - const actual = service.getModelContentQuery(model); - should.deepEqual(actual, expected); + const actual = queries.getModelContentQuery(model); + should.deepEqual(actual, expected, `actual: ${actual} \n expected: ${expected}`); }); }); diff --git a/extensions/machine-learning-services/src/test/prediction/predictService.test.ts b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts index 442a2e6fa0..3a2effe744 100644 --- a/extensions/machine-learning-services/src/test/prediction/predictService.test.ts +++ b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts @@ -10,7 +10,7 @@ import * as TypeMoq from 'typemoq'; import * as should from 'should'; import { PredictService } from '../../prediction/predictService'; import { QueryRunner } from '../../common/queryRunner'; -import { RegisteredModel } from '../../modelManagement/interfaces'; +import { ImportedModel } from '../../modelManagement/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import * as path from 'path'; import * as os from 'os'; @@ -194,11 +194,10 @@ describe('PredictService', () => { tableName: '', schema: '' }; - const model: RegisteredModel = + const model: ImportedModel = { id: 1, - artifactName: 'name1', - title: 'title1', + modelName: 'name1', description: 'desc1', created: '2018-01-01', version: '1.1', diff --git a/extensions/machine-learning-services/src/test/views/models/ModelManagementController.test.ts b/extensions/machine-learning-services/src/test/views/models/ModelManagementController.test.ts index 3202b9cb3a..a54178217e 100644 --- a/extensions/machine-learning-services/src/test/views/models/ModelManagementController.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/ModelManagementController.test.ts @@ -8,12 +8,14 @@ import * as should from 'should'; import * as TypeMoq from 'typemoq'; import 'mocha'; import { createContext } from './utils'; -import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces'; +import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces'; import { azureResource } from '../../../typings/azure-resource'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { WorkspaceModel } from '../../../modelManagement/interfaces'; import { ModelManagementController } from '../../../views/models/modelManagementController'; import { DatabaseTable, TableColumn } from '../../../prediction/interfaces'; +import { DeleteModelEventName, UpdateModelEventName } from '../../../views/models/modelViewBase'; +import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog'; const accounts: azdata.Account[] = [ { @@ -55,11 +57,10 @@ const models: WorkspaceModel[] = [ name: 'model' } ]; -const localModels: RegisteredModel[] = [ +const localModels: ImportedModel[] = [ { id: 1, - artifactName: 'model', - title: 'model', + modelName: 'model', table: { databaseName: 'db', tableName: 'tb', @@ -167,4 +168,35 @@ describe('Model Controller', () => { const view = await controller.predictModel(); should.notEqual(view, undefined); }); + + it('Should open edit model dialog successfully ', async function (): Promise { + let testContext = createContext(); + testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.deployModelService.setup(x => x.deleteModel(TypeMoq.It.isAny())).returns(() => Promise.resolve()); + + let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object); + const model: ImportedModel = + { + id: 1, + modelName: 'name1', + description: 'desc1', + created: '2018-01-01', + version: '1.1', + table: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } + }; + const view = await controller.editModel(model); + should.notEqual(view?.editModelPage, undefined); + if (view.editModelPage) { + view.editModelPage.sendRequest(UpdateModelEventName, model); + view.editModelPage.sendRequest(DeleteModelEventName, model); + } + testContext.deployModelService.verify(x => x.updateModel(model), TypeMoq.Times.atLeastOnce()); + testContext.deployModelService.verify(x => x.deleteModel(model), TypeMoq.Times.atLeastOnce()); + + should.notEqual(view, undefined); + }); }); diff --git a/extensions/machine-learning-services/src/test/views/models/editModelDialog.test.ts b/extensions/machine-learning-services/src/test/views/models/editModelDialog.test.ts new file mode 100644 index 0000000000..983895a3d9 --- /dev/null +++ b/extensions/machine-learning-services/src/test/views/models/editModelDialog.test.ts @@ -0,0 +1,33 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as should from 'should'; +import 'mocha'; +import { createContext } from './utils'; +import { ImportedModel } from '../../../modelManagement/interfaces'; +import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog'; + +describe('Edit Model Dialog', () => { + it('Should create view components successfully ', async function (): Promise { + let testContext = createContext(); + const model: ImportedModel = + { + id: 1, + modelName: 'name1', + description: 'desc1', + created: '2018-01-01', + version: '1.1', + table: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } + }; + let view = new EditModelDialog(testContext.apiWrapper.object, '', undefined, model); + view.open(); + + should.notEqual(view.dialogView, undefined); + }); +}); diff --git a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts index 817885971c..20bc31120d 100644 --- a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts @@ -12,7 +12,7 @@ import { ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType } from '../../../views/models/modelViewBase'; -import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces'; +import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces'; import { azureResource } from '../../../typings/azure-resource'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { ViewBase } from '../../../views/viewBase'; @@ -80,11 +80,10 @@ describe('Predict Wizard', () => { name: 'model' } ]; - let localModels: RegisteredModel[] = [ + let localModels: ImportedModel[] = [ { id: 1, - artifactName: 'model', - title: 'model', + modelName: 'model', table: { databaseName: 'db', tableName: 'tb', diff --git a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts index e8542183af..fae94d875e 100644 --- a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts @@ -8,7 +8,7 @@ import * as should from 'should'; import 'mocha'; import { createContext } from './utils'; import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase'; -import { RegisteredModel } from '../../../modelManagement/interfaces'; +import { ImportedModel } from '../../../modelManagement/interfaces'; import { azureResource } from '../../../typings/azure-resource'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { ViewBase } from '../../../views/viewBase'; @@ -80,11 +80,10 @@ describe('Register Model Wizard', () => { name: 'model' } ]; - let localModels: RegisteredModel[] = [ + let localModels: ImportedModel[] = [ { id: 1, - artifactName: 'model', - title: 'model', + modelName: 'model', table: { databaseName: 'db', tableName: 'tb', diff --git a/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts b/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts index 0a40b61de6..3409d6f929 100644 --- a/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts @@ -8,7 +8,7 @@ import 'mocha'; import { createContext } from './utils'; import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog'; import { ListModelsEventName } from '../../../views/models/modelViewBase'; -import { RegisteredModel } from '../../../modelManagement/interfaces'; +import { ImportedModel } from '../../../modelManagement/interfaces'; import { ViewBase } from '../../../views/viewBase'; describe('Registered Models Dialog', () => { @@ -27,11 +27,10 @@ describe('Registered Models Dialog', () => { let view = new ManageModelsDialog(testContext.apiWrapper.object, ''); view.open(); - let models: RegisteredModel[] = [ + let models: ImportedModel[] = [ { id: 1, - artifactName: 'model', - title: '', + modelName: 'model', table: { databaseName: 'db', tableName: 'tb', diff --git a/extensions/machine-learning-services/src/views/interfaces.ts b/extensions/machine-learning-services/src/views/interfaces.ts index bae2be5bd4..5c54a57d8b 100644 --- a/extensions/machine-learning-services/src/views/interfaces.ts +++ b/extensions/machine-learning-services/src/views/interfaces.ts @@ -34,4 +34,10 @@ export interface AzureModelResource extends AzureWorkspaceResource { model?: WorkspaceModel; } +export interface IComponentSettings { + multiSelect?: boolean; + editable?: boolean; + selectable?: boolean; +} + diff --git a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts index 8b36f31a2e..557981c738 100644 --- a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts @@ -10,11 +10,13 @@ import { AzureResourceFilterComponent } from './azureResourceFilterComponent'; import { AzureModelsTable } from './azureModelsTable'; import { IDataComponent, AzureModelResource } from '../interfaces'; import { ModelArtifact } from './prediction/modelArtifact'; +import { AzureSignInComponent } from './azureSignInComponent'; export class AzureModelsComponent extends ModelViewBase implements IDataComponent { public azureModelsTable: AzureModelsTable | undefined; public azureFilterComponent: AzureResourceFilterComponent | undefined; + public azureSignInComponent: AzureSignInComponent | undefined; private _loader: azdata.LoadingComponent | undefined; private _form: azdata.FormContainer | undefined; @@ -34,6 +36,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this); this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect); + this.azureSignInComponent = new AzureSignInComponent(this._apiWrapper, modelBuilder, this); this._loader = modelBuilder.loadingComponent() .withItem(this.azureModelsTable.component) .withProperties({ @@ -63,6 +66,20 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen } public addComponents(formBuilder: azdata.FormBuilder) { + this.removeComponents(formBuilder); + if (this.azureFilterComponent?.data?.account) { + this.addAzureComponents(formBuilder); + } else { + this.addAzureSignInComponents(formBuilder); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + this.removeAzureComponents(formBuilder); + this.removeAzureSignInComponents(formBuilder); + } + + private addAzureComponents(formBuilder: azdata.FormBuilder) { if (this.azureFilterComponent && this._loader) { this.azureFilterComponent.addComponents(formBuilder); @@ -73,7 +90,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen } } - public removeComponents(formBuilder: azdata.FormBuilder) { + private removeAzureComponents(formBuilder: azdata.FormBuilder) { if (this.azureFilterComponent && this._loader) { this.azureFilterComponent.removeComponents(formBuilder); formBuilder.removeFormItem({ @@ -83,6 +100,18 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen } } + private addAzureSignInComponents(formBuilder: azdata.FormBuilder) { + if (this.azureSignInComponent) { + this.azureSignInComponent.addComponents(formBuilder); + } + } + + private removeAzureSignInComponents(formBuilder: azdata.FormBuilder) { + if (this.azureSignInComponent) { + this.azureSignInComponent.removeComponents(formBuilder); + } + } + private async onLoading(): Promise { if (this._loader) { await this._loader.updateProperties({ loading: true }); diff --git a/extensions/machine-learning-services/src/views/models/azureSignInComponent.ts b/extensions/machine-learning-services/src/views/models/azureSignInComponent.ts new file mode 100644 index 0000000000..a54dc6d351 --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/azureSignInComponent.ts @@ -0,0 +1,69 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase, SignInToAzureEventName } from './modelViewBase'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as constants from '../../common/constants'; + +/** + * View to render filters to pick an azure resource + */ +const componentWidth = 300; +export class AzureSignInComponent extends ModelViewBase { + + private _form: azdata.FormContainer; + private _signInButton: azdata.ButtonComponent; + + /** + * Creates a new view + */ + constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) { + super(apiWrapper, parent.root, parent); + this._signInButton = this._modelBuilder.button().withProperties({ + width: componentWidth, + label: constants.azureSignIn, + }).component(); + this._signInButton.onDidClick(() => { + this.sendRequest(SignInToAzureEventName); + }); + + this._form = this._modelBuilder.formContainer().withFormItems([{ + title: constants.azureAccount, + component: this._signInButton + }]).component(); + } + + public addComponents(formBuilder: azdata.FormBuilder) { + if (this._signInButton) { + formBuilder.addFormItems([{ + title: constants.azureAccount, + component: this._signInButton + }]); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + if (this._signInButton) { + formBuilder.removeFormItem({ + title: constants.azureAccount, + component: this._signInButton + }); + } + } + + /** + * Returns the created component + */ + public get component(): azdata.Component { + return this._form; + } + + /** + * refreshes the view + */ + public async refresh(): Promise { + } +} diff --git a/extensions/machine-learning-services/src/views/models/manageModels/currentModelsComponent.ts b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsComponent.ts index 7dc339d7da..244136173f 100644 --- a/extensions/machine-learning-services/src/views/models/manageModels/currentModelsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsComponent.ts @@ -9,9 +9,9 @@ import * as constants from '../../../common/constants'; import { ModelViewBase } from '../modelViewBase'; import { CurrentModelsTable } from './currentModelsTable'; import { ApiWrapper } from '../../../common/apiWrapper'; -import { IPageView } from '../../interfaces'; +import { IPageView, IComponentSettings } from '../../interfaces'; import { TableSelectionComponent } from '../tableSelectionComponent'; -import { RegisteredModel } from '../../../modelManagement/interfaces'; +import { ImportedModel } from '../../../modelManagement/interfaces'; /** * View to render current registered models @@ -27,7 +27,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { * @param apiWrapper Creates new view * @param parent page parent */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = false) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) { super(apiWrapper, parent.root, parent); } @@ -41,7 +41,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { this._tableSelectionComponent.onSelectedChanged(async () => { await this.onTableSelected(); }); - this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect); + this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings); this._dataTable.registerComponent(modelBuilder); this._tableComponent = this._dataTable.component; @@ -102,7 +102,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { } } - public get data(): RegisteredModel[] | undefined { + public get data(): ImportedModel[] | undefined { return this._dataTable?.data; } diff --git a/extensions/machine-learning-services/src/views/models/manageModels/currentModelsTable.ts b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsTable.ts index 73c3a07de2..f9dfb22596 100644 --- a/extensions/machine-learning-services/src/views/models/manageModels/currentModelsTable.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsTable.ts @@ -6,20 +6,21 @@ import * as azdata from 'azdata'; import * as vscode from 'vscode'; import * as constants from '../../../common/constants'; -import { ModelViewBase } from '../modelViewBase'; +import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; -import { RegisteredModel } from '../../../modelManagement/interfaces'; -import { IDataComponent } from '../../interfaces'; +import { ImportedModel } from '../../../modelManagement/interfaces'; +import { IDataComponent, IComponentSettings } from '../../interfaces'; import { ModelArtifact } from '../prediction/modelArtifact'; +import * as utils from '../../../common/utils'; /** * View to render registered models table */ -export class CurrentModelsTable extends ModelViewBase implements IDataComponent { +export class CurrentModelsTable extends ModelViewBase implements IDataComponent { private _table: azdata.DeclarativeTableComponent | undefined; private _modelBuilder: azdata.ModelBuilder | undefined; - private _selectedModel: RegisteredModel[] = []; + private _selectedModel: ImportedModel[] = []; private _loader: azdata.LoadingComponent | undefined; private _downloadedFile: ModelArtifact | undefined; private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter(); @@ -28,7 +29,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< /** * Creates new view */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) { super(apiWrapper, parent.root, parent); } @@ -38,62 +39,66 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this._modelBuilder = modelBuilder; + let columns = [ + { // Name + displayName: constants.modelName, + ariaLabel: constants.modelName, + valueType: azdata.DeclarativeDataType.string, + isReadOnly: true, + width: 150, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Created + displayName: constants.modelCreated, + ariaLabel: constants.modelCreated, + valueType: azdata.DeclarativeDataType.string, + isReadOnly: true, + width: 150, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Action + displayName: '', + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + } + ]; + if (this._settings.editable) { + columns.push( + { // Action + displayName: '', + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + } + ); + } this._table = modelBuilder.declarativeTable() .withProperties( { - columns: [ - { // Artifact name - displayName: constants.modelArtifactName, - ariaLabel: constants.modelArtifactName, - valueType: azdata.DeclarativeDataType.string, - isReadOnly: true, - width: 150, - headerCssStyles: { - ...constants.cssStyles.tableHeader - }, - rowCssStyles: { - ...constants.cssStyles.tableRow - }, - }, - { // Name - displayName: constants.modelName, - ariaLabel: constants.modelName, - valueType: azdata.DeclarativeDataType.string, - isReadOnly: true, - width: 150, - headerCssStyles: { - ...constants.cssStyles.tableHeader - }, - rowCssStyles: { - ...constants.cssStyles.tableRow - }, - }, - { // Created - displayName: constants.modelCreated, - ariaLabel: constants.modelCreated, - valueType: azdata.DeclarativeDataType.string, - isReadOnly: true, - width: 150, - headerCssStyles: { - ...constants.cssStyles.tableHeader - }, - rowCssStyles: { - ...constants.cssStyles.tableRow - }, - }, - { // Action - displayName: '', - valueType: azdata.DeclarativeDataType.component, - isReadOnly: true, - width: 50, - headerCssStyles: { - ...constants.cssStyles.tableHeader - }, - rowCssStyles: { - ...constants.cssStyles.tableRow - }, - } - ], + columns: columns, data: [], ariaLabel: constants.mlsConfigTitle }) @@ -132,7 +137,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< public async loadData(): Promise { await this.onLoading(); if (this._table) { - let models: RegisteredModel[] | undefined; + let models: ImportedModel[] | undefined; if (this.importTable) { models = await this.listModels(this.importTable); @@ -163,11 +168,28 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< } } - private createTableRow(model: RegisteredModel): any[] { + private createTableRow(model: ImportedModel): any[] { + let row: any[] = [model.modelName, model.created]; if (this._modelBuilder) { - let selectModelButton: azdata.Component; + const selectButton = this.createSelectButton(model); + if (selectButton) { + row.push(selectButton); + } + const editButtons = this.createEditButtons(model); + if (editButtons && editButtons.length > 0) { + row = row.concat(editButtons); + } + } + + return row; + } + + private createSelectButton(model: ImportedModel): azdata.Component | undefined { + let selectModelButton: azdata.Component | undefined = undefined; + if (this._modelBuilder && this._settings.selectable) { + let onSelectItem = (checked: boolean) => { - if (!this._multiSelect) { + if (!this._settings.multiSelect) { this._selectedModel = []; } const foundItem = this._selectedModel.find(x => x === model); @@ -178,7 +200,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< } this.onModelSelected(); }; - if (this._multiSelect) { + if (this._settings.multiSelect) { const checkbox = this._modelBuilder.checkBox().withProperties({ name: 'amlModel', value: model.id, @@ -203,11 +225,53 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< }); selectModelButton = radioButton; } - - return [model.artifactName, model.title, model.created, selectModelButton]; } + return selectModelButton; + } - return []; + private createEditButtons(model: ImportedModel): azdata.Component[] | undefined { + let dropButton: azdata.ButtonComponent | undefined = undefined; + let editButton: azdata.ButtonComponent | undefined = undefined; + if (this._modelBuilder && this._settings.editable) { + dropButton = this._modelBuilder.button().withProperties({ + label: '', + title: constants.deleteTitle, + iconPath: { + dark: this.asAbsolutePath('images/dark/delete_inverse.svg'), + light: this.asAbsolutePath('images/light/delete.svg') + }, + width: 15, + height: 15 + }).component(); + dropButton.onDidClick(async () => { + try { + const confirm = await utils.promptConfirm(constants.confirmDeleteModel(model.modelName), this._apiWrapper); + if (confirm) { + await this.sendDataRequest(DeleteModelEventName, model); + if (this.parent) { + await this.parent?.refresh(); + } + } + } catch (error) { + this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`); + } + }); + + editButton = this._modelBuilder.button().withProperties({ + label: '', + title: constants.deleteTitle, + iconPath: { + dark: this.asAbsolutePath('images/dark/edit_inverse.svg'), + light: this.asAbsolutePath('images/light/edit.svg') + }, + width: 15, + height: 15 + }).component(); + editButton.onDidClick(async () => { + await this.sendDataRequest(EditModelEventName, model); + }); + } + return editButton && dropButton ? [editButton, dropButton] : undefined; } private async onModelSelected(): Promise { @@ -221,7 +285,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< /** * Returns selected data */ - public get data(): RegisteredModel[] | undefined { + public get data(): ImportedModel[] | undefined { return this._selectedModel; } diff --git a/extensions/machine-learning-services/src/views/models/manageModels/editModelDialog.ts b/extensions/machine-learning-services/src/views/models/manageModels/editModelDialog.ts new file mode 100644 index 0000000000..dbe795346d --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/manageModels/editModelDialog.ts @@ -0,0 +1,75 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import { ModelViewBase, UpdateModelEventName } from '../modelViewBase'; +import * as constants from '../../../common/constants'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import { DialogView } from '../../dialogView'; +import { ModelDetailsEditPage } from './modelDetailsEditPage'; +import { ImportedModel } from '../../../modelManagement/interfaces'; + +/** + * Dialog to render registered model views + */ +export class EditModelDialog extends ModelViewBase { + + constructor( + apiWrapper: ApiWrapper, + root: string, + private _parentView: ModelViewBase | undefined, + private _model: ImportedModel) { + super(apiWrapper, root); + this.dialogView = new DialogView(this._apiWrapper); + } + public dialogView: DialogView; + public editModelPage: ModelDetailsEditPage | undefined; + + /** + * Opens a dialog to edit models. + */ + public open(): void { + + this.editModelPage = new ModelDetailsEditPage(this._apiWrapper, this, this._model); + + let registerModelButton = this._apiWrapper.createButton(constants.extLangSaveButtonText); + registerModelButton.onClick(async () => { + if (this.editModelPage) { + const valid = await this.editModelPage.validate(); + if (valid) { + try { + await this.sendDataRequest(UpdateModelEventName, this.editModelPage?.data); + this.showInfoMessage(constants.modelUpdatedSuccessfully); + if (this._parentView) { + await this._parentView.refresh(); + } + } catch (error) { + this.showInfoMessage(`${constants.modelUpdateFailedError} ${constants.getErrorMessage(error)}`); + } + } + } + }); + + let dialog = this.dialogView.createDialog(constants.editModelTitle, [this.editModelPage]); + dialog.customButtons = [registerModelButton]; + this.mainViewPanel = dialog; + dialog.okButton.hidden = true; + dialog.cancelButton.label = constants.extLangDoneButtonText; + + dialog.registerCloseValidator(() => { + return false; // Blocks Enter key from closing dialog. + }); + + this._apiWrapper.openDialog(dialog); + } + + /** + * Resets the tabs for given provider Id + */ + public async refresh(): Promise { + if (this.dialogView) { + this.dialogView.refresh(); + } + } +} diff --git a/extensions/machine-learning-services/src/views/models/manageModels/importModelWizard.ts b/extensions/machine-learning-services/src/views/models/manageModels/importModelWizard.ts index 709fcd89bf..c388d9b356 100644 --- a/extensions/machine-learning-services/src/views/models/manageModels/importModelWizard.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/importModelWizard.ts @@ -14,7 +14,7 @@ import { WizardView } from '../../wizardView'; import { ModelSourcePage } from '../modelSourcePage'; import { ModelDetailsPage } from '../modelDetailsPage'; import { ModelBrowsePage } from '../modelBrowsePage'; -import { ModelImportLocationPage } from './modelmportLocationPage'; +import { ModelImportLocationPage } from './modelImportLocationPage'; /** * Wizard to register a model diff --git a/extensions/machine-learning-services/src/views/models/manageModels/manageModelsDialog.ts b/extensions/machine-learning-services/src/views/models/manageModels/manageModelsDialog.ts index 2dee324da2..7e16f55220 100644 --- a/extensions/machine-learning-services/src/views/models/manageModels/manageModelsDialog.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/manageModelsDialog.ts @@ -29,7 +29,10 @@ export class ManageModelsDialog extends ModelViewBase { */ public open(): void { - this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this); + this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this, { + editable: true, + selectable: false + }); let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle); registerModelButton.onClick(async () => { diff --git a/extensions/machine-learning-services/src/views/models/manageModels/modelDetailsComponent.ts b/extensions/machine-learning-services/src/views/models/manageModels/modelDetailsComponent.ts new file mode 100644 index 0000000000..eccc01449d --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/manageModels/modelDetailsComponent.ts @@ -0,0 +1,154 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import * as constants from '../../../common/constants'; +import { IDataComponent } from '../../interfaces'; +import { ImportedModel } from '../../../modelManagement/interfaces'; + +/** + * View to render filters to pick an azure resource + */ +export class ModelDetailsComponent extends ModelViewBase implements IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _nameComponent: azdata.InputBoxComponent | undefined; + private _descriptionComponent: azdata.InputBoxComponent | undefined; + private _createdComponent: azdata.Component | undefined; + private _deployedComponent: azdata.Component | undefined; + private _frameworkComponent: azdata.Component | undefined; + private _frameworkVersionComponent: azdata.Component | undefined; + /** + * Creates a new view + */ + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) { + super(apiWrapper, parent.root, parent); + } + + /** + * Register components + * @param modelBuilder model builder + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + this._createdComponent = modelBuilder.text().withProperties({ + value: this._model.created + }).component(); + this._deployedComponent = modelBuilder.text().withProperties({ + value: this._model.deploymentTime + }).component(); + this._frameworkComponent = modelBuilder.text().withProperties({ + value: this._model.framework + }).component(); + this._frameworkVersionComponent = modelBuilder.text().withProperties({ + value: this._model.frameworkVersion + }).component(); + this._nameComponent = modelBuilder.inputBox().withProperties({ + width: this.componentMaxLength, + value: this._model.modelName + }).component(); + this._descriptionComponent = modelBuilder.inputBox().withProperties({ + width: this.componentMaxLength, + value: this._model.description, + multiline: true, + height: 50 + }).component(); + + this._form = modelBuilder.formContainer().withFormItems([{ + title: '', + component: this._nameComponent + }, + { + title: '', + component: this._descriptionComponent + }]).component(); + return this._form; + } + + public addComponents(formBuilder: azdata.FormBuilder) { + if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) { + formBuilder.addFormItems([{ + title: constants.modelName, + component: this._nameComponent + }, { + title: constants.modelCreated, + component: this._createdComponent + }, + { + title: constants.modelDeployed, + component: this._deployedComponent + }, { + title: constants.modelFramework, + component: this._frameworkComponent + }, { + title: constants.modelFrameworkVersion, + component: this._frameworkVersionComponent + }, { + title: constants.modelDescription, + component: this._descriptionComponent + }]); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) { + formBuilder.removeFormItem({ + title: constants.modelCreated, + component: this._createdComponent + }); + formBuilder.removeFormItem({ + title: constants.modelCreated, + component: this._frameworkComponent + }); + formBuilder.removeFormItem({ + title: constants.modelCreated, + component: this._frameworkVersionComponent + }); + formBuilder.removeFormItem({ + title: constants.modelCreated, + component: this._deployedComponent + }); + formBuilder.removeFormItem({ + title: constants.modelName, + component: this._nameComponent + }); + formBuilder.removeFormItem({ + title: constants.modelDescription, + component: this._descriptionComponent + }); + } + } + + /** + * Returns the created component + */ + public get component(): azdata.Component | undefined { + return this._form; + } + + /** + * Returns selected data + */ + public get data(): ImportedModel | undefined { + let model = Object.assign({}, this._model); + model.modelName = this._nameComponent?.value || ''; + model.description = this._descriptionComponent?.value || ''; + return model; + } + + /** + * loads data in the components + */ + public async loadData(): Promise { + } + + /** + * refreshes the view + */ + public async refresh(): Promise { + await this.loadData(); + } +} diff --git a/extensions/machine-learning-services/src/views/models/manageModels/modelDetailsEditPage.ts b/extensions/machine-learning-services/src/views/models/manageModels/modelDetailsEditPage.ts new file mode 100644 index 0000000000..936f785ead --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/manageModels/modelDetailsEditPage.ts @@ -0,0 +1,85 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import * as constants from '../../../common/constants'; +import { IPageView, IDataComponent } from '../../interfaces'; +import { ImportedModel } from '../../../modelManagement/interfaces'; +import { ModelDetailsComponent } from './modelDetailsComponent'; + +/** + * View to pick model source + */ +export class ModelDetailsEditPage extends ModelViewBase implements IPageView, IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _formBuilder: azdata.FormBuilder | undefined; + public modelDetailsComponent: ModelDetailsComponent | undefined; + + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) { + super(apiWrapper, parent.root, parent); + } + + /** + * + * @param modelBuilder Register components + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + + this._formBuilder = modelBuilder.formContainer(); + this.modelDetailsComponent = new ModelDetailsComponent(this._apiWrapper, this, this._model); + + this.modelDetailsComponent.registerComponent(modelBuilder); + this.modelDetailsComponent.addComponents(this._formBuilder); + this._form = this._formBuilder.component(); + return this._form; + } + + /** + * Returns selected data + */ + public get data(): ImportedModel | undefined { + return this.modelDetailsComponent?.data; + } + + /** + * Returns the component + */ + public get component(): azdata.Component | undefined { + return this._form; + } + + /** + * Refreshes the view + */ + public async refresh(): Promise { + if (this.modelDetailsComponent) { + await this.modelDetailsComponent.refresh(); + } + } + + /** + * Returns page title + */ + public get title(): string { + return constants.modelImportTargetPageTitle; + } + + public async disposePage(): Promise { + } + + public async validate(): Promise { + let validated = false; + + if (this.data?.modelName) { + validated = true; + } else { + this.showErrorMessage(constants.modelNameRequiredError); + } + return validated; + } +} diff --git a/extensions/machine-learning-services/src/views/models/manageModels/modelmportLocationPage.ts b/extensions/machine-learning-services/src/views/models/manageModels/modelImportLocationPage.ts similarity index 97% rename from extensions/machine-learning-services/src/views/models/manageModels/modelmportLocationPage.ts rename to extensions/machine-learning-services/src/views/models/manageModels/modelImportLocationPage.ts index 8f8098b360..740d9af157 100644 --- a/extensions/machine-learning-services/src/views/models/manageModels/modelmportLocationPage.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/modelImportLocationPage.ts @@ -44,7 +44,6 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, private async onTableSelected(): Promise { if (this.tableSelectionComponent?.data) { this.importTable = this.tableSelectionComponent?.data; - //this.sendRequest(StoreImportTableEventName, this.importTable); } } diff --git a/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts b/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts index 146f1b1281..90bc74b0d7 100644 --- a/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts +++ b/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts @@ -19,7 +19,7 @@ import { CurrentModelsComponent } from './manageModels/currentModelsComponent'; export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent { private _form: azdata.FormContainer | undefined; - private _title: string = constants.modelSourcePageTitle; + private _title: string = constants.localModelPageTitle; private _formBuilder: azdata.FormBuilder | undefined; public localModelsComponent: LocalModelsComponent | undefined; public azureModelsComponent: AzureModelsComponent | undefined; @@ -40,7 +40,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo this.localModelsComponent.registerComponent(modelBuilder); this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect); this.azureModelsComponent.registerComponent(modelBuilder); - this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, this._multiSelect); + this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, { + selectable: true, + multiSelect: this._multiSelect, + editable: false + }); this.registeredModelsComponent.registerComponent(modelBuilder); this.refresh(); this._form = this._formBuilder.component(); @@ -96,12 +100,12 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo private loadTitle(): void { if (this.modelSourceType === ModelSourceType.Local) { - this._title = 'Upload model file'; + this._title = constants.localModelPageTitle; } else if (this.modelSourceType === ModelSourceType.Azure) { - this._title = 'Import from Azure Machine Learning'; + this._title = constants.azureModelPageTitle; } else if (this.modelSourceType === ModelSourceType.RegisteredModels) { - this._title = 'Select imported model'; + this._title = constants.importedModelsPageTitle; } else { this._title = constants.modelSourcePageTitle; } @@ -111,6 +115,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo * Returns page title */ public get title(): string { + this.loadTitle(); return this._title; } @@ -144,7 +149,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo return { modelData: x, modelDetails: { - title: fileName, + modelName: fileName, fileName: fileName }, targetImportTable: this.importTable @@ -164,8 +169,11 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo model: x.model }, modelDetails: { - title: x.model?.name || '', - fileName: x.model?.name + modelName: x.model?.name || '', + fileName: x.model?.name, + framework: x.model?.framework, + frameworkVersion: x.model?.frameworkVersion, + created: x.model?.createdTime }, targetImportTable: this.importTable }; @@ -178,7 +186,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo return { modelData: x, modelDetails: { - title: '' + modelName: '' }, targetImportTable: this.importTable }; diff --git a/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts b/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts index 1e8fbb9779..4efcc6fac4 100644 --- a/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts +++ b/extensions/machine-learning-services/src/views/models/modelDetailsPage.ts @@ -8,7 +8,7 @@ import { ModelViewBase, ModelViewData } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; import * as constants from '../../common/constants'; import { IPageView, IDataComponent } from '../interfaces'; -import { ModelDetailsComponent } from './modelDetailsComponent'; +import { ModelsDetailsTableComponent } from './modelsDetailsTableComponent'; /** * View to pick model details @@ -17,7 +17,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC private _form: azdata.FormContainer | undefined; private _formBuilder: azdata.FormBuilder | undefined; - public modelDetails: ModelDetailsComponent | undefined; + public modelDetails: ModelsDetailsTableComponent | undefined; constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { super(apiWrapper, parent.root, parent); @@ -30,7 +30,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this._formBuilder = modelBuilder.formContainer(); - this.modelDetails = new ModelDetailsComponent(this._apiWrapper, modelBuilder, this); + this.modelDetails = new ModelsDetailsTableComponent(this._apiWrapper, modelBuilder, this); this.modelDetails.registerComponent(modelBuilder); this.modelDetails.addComponents(this._formBuilder); this.refresh(); @@ -73,7 +73,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC } public validate(): Promise { - if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.title)) { + if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.modelName)) { return Promise.resolve(true); } else { this.showErrorMessage(constants.modelNameRequiredError); diff --git a/extensions/machine-learning-services/src/views/models/modelManagementController.ts b/extensions/machine-learning-services/src/views/models/modelManagementController.ts index b6b60b5fd4..e0925477a0 100644 --- a/extensions/machine-learning-services/src/views/models/modelManagementController.ts +++ b/extensions/machine-learning-services/src/views/models/modelManagementController.ts @@ -9,7 +9,7 @@ import { azureResource } from '../../typings/azure-resource'; import { ApiWrapper } from '../../common/apiWrapper'; import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; -import { RegisteredModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces'; +import { ImportedModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { DeployedModelService } from '../../modelManagement/deployedModelService'; import { ManageModelsDialog } from './manageModels/manageModelsDialog'; @@ -17,7 +17,7 @@ import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, - ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName + ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName } from './modelViewBase'; import { ControllerBase } from '../controllerBase'; import { ImportModelWizard } from './manageModels/importModelWizard'; @@ -26,6 +26,7 @@ import * as constants from '../../common/constants'; import { PredictWizard } from './prediction/predictWizard'; import { AzureModelResource } from '../interfaces'; import { PredictService } from '../../prediction/predictService'; +import { EditModelDialog } from './manageModels/editModelDialog'; /** * Model management UI controller @@ -71,6 +72,24 @@ export class ModelManagementController extends ControllerBase { return view; } + /** + * Opens the dialog to edit model + */ + public async editModel(model: ImportedModel, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise { + controller = controller || this; + apiWrapper = apiWrapper || this._apiWrapper; + root = root || this._root; + let view = new EditModelDialog(apiWrapper, root, parent, model); + + controller.registerEvents(view); + + // Open view + // + await view.open(); + await view.refresh(); + return view; + } + /** * Opens the wizard for prediction */ @@ -136,6 +155,18 @@ export class ModelManagementController extends ControllerBase { const importTable = args; await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root); }); + view.on(EditModelEventName, async (args) => { + const model = args; + await this.executeAction(view, EditModelEventName, this.editModel, model, view, this, this._apiWrapper, this._root); + }); + view.on(UpdateModelEventName, async (args) => { + const model = args; + await this.executeAction(view, UpdateModelEventName, this.updateModel, this._registeredModelService, model); + }); + view.on(DeleteModelEventName, async (args) => { + const model = args; + await this.executeAction(view, DeleteModelEventName, this.deleteModel, this._registeredModelService, model); + }); view.on(RegisterAzureModelEventName, async (arg) => { let models = arg; await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService, @@ -164,7 +195,7 @@ export class ModelManagementController extends ControllerBase { predictArgs, predictArgs.model, predictArgs.filePath); }); view.on(DownloadRegisteredModelEventName, async (arg) => { - let model = arg; + let model = arg; await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService, model); }); @@ -178,9 +209,13 @@ export class ModelManagementController extends ControllerBase { await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService, importTable); }); - view.on(SourceModelSelectedEventName, (arg) => { + view.on(SourceModelSelectedEventName, async (arg) => { view.modelSourceType = arg; - view.refresh(); + await view.refresh(); + }); + view.on(SignInToAzureEventName, async () => { + await this.executeAction(view, SignInToAzureEventName, this.signInToAzure, this._amlService); + await view.refresh(); }); } @@ -206,6 +241,10 @@ export class ModelManagementController extends ControllerBase { return view; } + private async signInToAzure(service: AzureModelRegistryService): Promise { + return await service.signInToAzure(); + } + private async getAzureAccounts(service: AzureModelRegistryService): Promise { return await service.getAccounts(); } @@ -225,7 +264,7 @@ export class ModelManagementController extends ControllerBase { return await service.getWorkspaces(account, subscription, group); } - private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise { + private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise { return registeredModelService.getDeployedModels(table); } @@ -258,6 +297,22 @@ export class ModelManagementController extends ControllerBase { } } + private async updateModel(service: DeployedModelService, model: ImportedModel | undefined): Promise { + if (model) { + await service.updateModel(model); + } else { + throw Error(constants.invalidModelToRegisterError); + } + } + + private async deleteModel(service: DeployedModelService, model: ImportedModel | undefined): Promise { + if (model) { + await service.deleteModel(model); + } else { + throw Error(constants.invalidModelToRegisterError); + } + } + private async registerAzureModel( azureService: AzureModelRegistryService, service: DeployedModelService, @@ -306,7 +361,7 @@ export class ModelManagementController extends ControllerBase { private async generatePredictScript( predictService: PredictService, predictParams: PredictParameters, - registeredModel: RegisteredModel | undefined, + registeredModel: ImportedModel | undefined, filePath: string | undefined ): Promise { if (!predictParams) { @@ -334,7 +389,7 @@ export class ModelManagementController extends ControllerBase { private async downloadRegisteredModel( registeredModelService: DeployedModelService, - model: RegisteredModel | undefined): Promise { + model: ImportedModel | undefined): Promise { if (!model) { throw Error(constants.invalidModelToPredictError); } diff --git a/extensions/machine-learning-services/src/views/models/modelViewBase.ts b/extensions/machine-learning-services/src/views/models/modelViewBase.ts index 65ddc53270..0330d41925 100644 --- a/extensions/machine-learning-services/src/views/models/modelViewBase.ts +++ b/extensions/machine-learning-services/src/views/models/modelViewBase.ts @@ -8,7 +8,7 @@ import * as azdata from 'azdata'; import { azureResource } from '../../typings/azure-resource'; import { ApiWrapper } from '../../common/apiWrapper'; import { ViewBase } from '../viewBase'; -import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces'; +import { ImportedModel, WorkspaceModel, ImportedModelDetails, ModelParameters } from '../../modelManagement/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { AzureWorkspaceResource, AzureModelResource } from '../interfaces'; @@ -18,11 +18,11 @@ export interface AzureResourceEventArgs extends AzureWorkspaceResource { } export interface RegisterModelEventArgs extends AzureWorkspaceResource { - details?: RegisteredModelDetails + details?: ImportedModelDetails } export interface PredictModelEventArgs extends PredictParameters { - model?: RegisteredModel; + model?: ImportedModel; filePath?: string; } @@ -35,8 +35,8 @@ export enum ModelSourceType { export interface ModelViewData { modelFile?: string; - modelData: AzureModelResource | string | RegisteredModel; - modelDetails?: RegisteredModelDetails; + modelData: AzureModelResource | string | ImportedModel; + modelDetails?: ImportedModelDetails; targetImportTable?: DatabaseTable; } @@ -57,10 +57,14 @@ export const DownloadAzureModelEventName = 'downloadAzureLocalModel'; export const DownloadRegisteredModelEventName = 'downloadRegisteredModel'; export const PredictModelEventName = 'predictModel'; export const RegisterModelEventName = 'registerModel'; +export const EditModelEventName = 'editModel'; +export const UpdateModelEventName = 'updateModel'; +export const DeleteModelEventName = 'deleteModel'; export const SourceModelSelectedEventName = 'sourceModelSelected'; export const LoadModelParametersEventName = 'loadModelParameters'; export const StoreImportTableEventName = 'storeImportTable'; export const VerifyImportTableEventName = 'verifyImportTable'; +export const SignInToAzureEventName = 'signInToAzure'; /** * Base class for all model management views @@ -94,7 +98,11 @@ export abstract class ModelViewBase extends ViewBase { DownloadRegisteredModelEventName, LoadModelParametersEventName, StoreImportTableEventName, - VerifyImportTableEventName]); + VerifyImportTableEventName, + EditModelEventName, + UpdateModelEventName, + DeleteModelEventName, + SignInToAzureEventName]); } /** @@ -115,7 +123,7 @@ export abstract class ModelViewBase extends ViewBase { /** * list registered models */ - public async listModels(table: DatabaseTable): Promise { + public async listModels(table: DatabaseTable): Promise { return await this.sendDataRequest(ListModelsEventName, table); } @@ -170,7 +178,7 @@ export abstract class ModelViewBase extends ViewBase { * downloads registered model * @param model model to download */ - public async downloadRegisteredModel(model: RegisteredModel | undefined): Promise { + public async downloadRegisteredModel(model: ImportedModel | undefined): Promise { return await this.sendDataRequest(DownloadRegisteredModelEventName, model); } @@ -215,7 +223,7 @@ export abstract class ModelViewBase extends ViewBase { * registers azure model * @param args azure resource */ - public async generatePredictScript(model: RegisteredModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise { + public async generatePredictScript(model: ImportedModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise { const args: PredictModelEventArgs = Object.assign({}, params, { model: model, filePath: filePath, diff --git a/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts b/extensions/machine-learning-services/src/views/models/modelsDetailsTableComponent.ts similarity index 95% rename from extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts rename to extensions/machine-learning-services/src/views/models/modelsDetailsTableComponent.ts index 9c9e0b1de5..8bc3a318f2 100644 --- a/extensions/machine-learning-services/src/views/models/modelDetailsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/modelsDetailsTableComponent.ts @@ -12,7 +12,7 @@ import { IDataComponent } from '../interfaces'; /** * View to pick local models file */ -export class ModelDetailsComponent extends ModelViewBase implements IDataComponent { +export class ModelsDetailsTableComponent extends ModelViewBase implements IDataComponent { private _table: azdata.DeclarativeTableComponent | undefined; /** @@ -127,7 +127,7 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone private createTableRow(model: ModelViewData | undefined): any[] { if (this._modelBuilder && model && model.modelDetails) { const nameComponent = this._modelBuilder.inputBox().withProperties({ - value: model.modelDetails.title, + value: model.modelDetails.modelName, width: this.componentMaxLength - 100, required: true }).component(); @@ -142,7 +142,7 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone }); nameComponent.onTextChanged(() => { if (model.modelDetails) { - model.modelDetails.title = nameComponent.value || ''; + model.modelDetails.modelName = nameComponent.value || ''; } }); let deleteButton = this._modelBuilder.button().withProperties({ diff --git a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts index 572c046982..82c0e8b9a1 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts @@ -13,7 +13,7 @@ import * as constants from '../../../common/constants'; import { WizardView } from '../../wizardView'; import { ModelSourcePage } from '../modelSourcePage'; import { ColumnsSelectionPage } from './columnsSelectionPage'; -import { RegisteredModel } from '../../../modelManagement/interfaces'; +import { ImportedModel } from '../../../modelManagement/interfaces'; import { ModelArtifact } from './modelArtifact'; import { ModelBrowsePage } from '../modelBrowsePage'; @@ -124,7 +124,7 @@ export class PredictWizard extends ModelViewBase { private async predict(): Promise { try { let modelFilePath: string | undefined; - let registeredModel: RegisteredModel | undefined = undefined; + let registeredModel: ImportedModel | undefined = undefined; if (this.modelResources && this.modelResources.data && this.modelResources.data === ModelSourceType.RegisteredModels && this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) { const data = this.modelBrowsePage?.registeredModelsComponent?.data; diff --git a/extensions/machine-learning-services/src/views/models/tableSelectionComponent.ts b/extensions/machine-learning-services/src/views/models/tableSelectionComponent.ts index 537c687f2e..8259427bc9 100644 --- a/extensions/machine-learning-services/src/views/models/tableSelectionComponent.ts +++ b/extensions/machine-learning-services/src/views/models/tableSelectionComponent.ts @@ -168,7 +168,7 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo this._selectedTableName = this.getTableFullName(selectedTable); this._tables.value = this.getTableFullName(selectedTable); } else { - this._selectedTableName = this.getTableFullName(this._tableNames[0]); + this._selectedTableName = this._editable ? this.getTableFullName(this.importTable) : this.getTableFullName(this._tableNames[0]); } } else { this._selectedTableName = this.getTableFullName(this._tableNames[0]); diff --git a/extensions/machine-learning-services/src/views/wizardView.ts b/extensions/machine-learning-services/src/views/wizardView.ts index 3067191cdc..c25a682180 100644 --- a/extensions/machine-learning-services/src/views/wizardView.ts +++ b/extensions/machine-learning-services/src/views/wizardView.ts @@ -35,13 +35,14 @@ export class WizardView extends MainViewBase { */ public addWizardPage(page: IPageView, index: number): void { if (this._wizard) { - this.addPage(page, index); - this._wizard.removePage(index); - if (!page.viewPanel) { + const currentPage = this._wizard.currentPage; + if (page && currentPage < index) { + this.addPage(page, index); + this._wizard.removePage(index); this.createWizardPage(page.title || '', page); + this._wizard.addPage(page.viewPanel, index); + this._wizard.setCurrentPage(currentPage); } - this._wizard.addPage(page.viewPanel, index); - this._wizard.setCurrentPage(index); } } @@ -109,4 +110,14 @@ export class WizardView extends MainViewBase { public get wizard(): azdata.window.Wizard | undefined { return this._wizard; } + + public async refresh(): Promise { + for (let index = 0; index < this._pages.length; index++) { + const page = this._pages[index]; + if (this._wizard?.pages[index]?.title !== page.title) { + this.addWizardPage(page, index); + } + } + await super.refresh(); + } }