diff --git a/extensions/machine-learning-services/src/common/constants.ts b/extensions/machine-learning-services/src/common/constants.ts index 6a580d6559..e43283c6bc 100644 --- a/extensions/machine-learning-services/src/common/constants.ts +++ b/extensions/machine-learning-services/src/common/constants.ts @@ -115,6 +115,8 @@ export const extLangInstallFailedError = localize('extLang.installFailedError', export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language"); export const modelArtifactName = localize('models.artifactName', "Artifact Name"); +export const databaseName = localize('databaseName', "Database name"); +export const tableName = localize('tableName', "Table name"); export const modelName = localize('models.name', "Name"); export const modelFileName = localize('models.fileName', "File"); export const modelDescription = localize('models.description', "Description"); @@ -140,13 +142,14 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model export const localModelsTitle = localize('models.localModelsTitle', "Local models"); export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location"); export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?"); +export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Where do you want import models to?"); export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input"); export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details"); export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file"); export const currentModelsTitle = localize('models.currentModelsTitle', "Models"); export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy"); export const predictModel = localize('models.predictModel', "Predict"); -export const registerModelTitle = localize('models.RegisterWizard', "Deployed models"); +export const registerModelTitle = localize('models.RegisterWizard', "Import models"); export const importModelTitle = localize('models.importModelTitle', "Import models"); export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model"); export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions"); @@ -163,9 +166,12 @@ export const invalidAzureResourceError = localize('models.invalidAzureResourceEr export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register"); export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict"); export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model"); +export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table"); export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); } +export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); } +export function invalidImportTableSchemaError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableSchemaError', "Table schema is not supported for model import. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); } export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'"); export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported"); diff --git a/extensions/machine-learning-services/src/common/queryRunner.ts b/extensions/machine-learning-services/src/common/queryRunner.ts index ec46cee7d7..18b3922c23 100644 --- a/extensions/machine-learning-services/src/common/queryRunner.ts +++ b/extensions/machine-learning-services/src/common/queryRunner.ts @@ -183,10 +183,31 @@ export class QueryRunner { try { return await this.runQuery(connection, query); } catch (error) { - console.log(error); + //console.log(error); return undefined; } } + + /** + * Executes the query but doesn't fail it is fails + * @param connection SQL connection + * @param query query to run + */ + public async runWithDatabaseChange(connection: azdata.connection.ConnectionProfile, query: string, queryDb: string): Promise { + if (connection) { + try { + return await this.runQuery(connection, ` + USE [${utils.doubleEscapeSingleBrackets(queryDb)}] + ${query}`); + } catch (error) { + console.log(error); + } + finally { + this.safeRunQuery(connection, `USE [${utils.doubleEscapeSingleBrackets(connection.databaseName || 'master')}]`); + } + } + return undefined; + } } diff --git a/extensions/machine-learning-services/src/common/utils.ts b/extensions/machine-learning-services/src/common/utils.ts index aac4966ada..892aad9f30 100644 --- a/extensions/machine-learning-services/src/common/utils.ts +++ b/extensions/machine-learning-services/src/common/utils.ts @@ -11,7 +11,6 @@ import * as fs from 'fs'; import * as constants from '../common/constants'; import { promisify } from 'util'; import { ApiWrapper } from './apiWrapper'; -import { Config } from '../configurations/config'; export async function execCommandOnTempFile(content: string, command: (filePath: string) => Promise): Promise { let tempFilePath: string = ''; @@ -221,21 +220,21 @@ export function getScriptWithDBChange(currentDb: string, databaseName: string, s * Returns full name of model registration table * @param config config */ -export function getRegisteredModelsThreePartsName(config: Config) { - const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName); - const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName); - const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName); - return `[${dbName}].[${schema}].[${tableName}]`; +export function getRegisteredModelsThreePartsName(db: string, table: string, schema: string) { + const dbName = doubleEscapeSingleBrackets(db); + const schemaName = doubleEscapeSingleBrackets(schema); + const tableName = doubleEscapeSingleBrackets(table); + return `[${dbName}].[${schemaName}].[${tableName}]`; } /** * Returns full name of model registration table * @param config config object */ -export function getRegisteredModelsTowPartsName(config: Config) { - const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName); - const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName); - return `[${schema}].[${tableName}]`; +export function getRegisteredModelsTwoPartsName(table: string, schema: string) { + const schemaName = doubleEscapeSingleBrackets(schema); + const tableName = doubleEscapeSingleBrackets(table); + return `[${schemaName}].[${tableName}]`; } /** diff --git a/extensions/machine-learning-services/src/controllers/mainController.ts b/extensions/machine-learning-services/src/controllers/mainController.ts index 005c600582..e32ec99e2f 100644 --- a/extensions/machine-learning-services/src/controllers/mainController.ts +++ b/extensions/machine-learning-services/src/controllers/mainController.ts @@ -23,6 +23,7 @@ import { AzureModelRegistryService } from '../modelManagement/azureModelRegistry import { ModelPythonClient } from '../modelManagement/modelPythonClient'; import { PredictService } from '../prediction/predictService'; import { DashboardWidget } from '../views/widgets/dashboardWidget'; +import { ModelConfigRecent } from '../modelManagement/modelConfigRecent'; /** * The main controller class that initializes the extension @@ -102,12 +103,13 @@ export default class MainController implements vscode.Disposable { let languagesModel = new LanguageService(this._apiWrapper, mssqlService); let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel); let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager); + let modelRecentService = new ModelConfigRecent(this._context.globalState); // Model Management // - let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter); + let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter, modelRecentService); let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel); - let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config); + let predictService = new PredictService(this._apiWrapper, this._queryRunner); let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, azureModelsService, registeredModelService, predictService); @@ -121,7 +123,7 @@ export default class MainController implements vscode.Disposable { await modelManagementController.manageRegisteredModels(); })); this._apiWrapper.registerCommand(constants.mlImportModelCommand, (async () => { - await modelManagementController.registerModel(); + await modelManagementController.registerModel(undefined); })); this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => { await modelManagementController.predictModel(); @@ -135,15 +137,6 @@ export default class MainController implements vscode.Disposable { this._apiWrapper.registerTaskHandler(constants.mlManageLanguagesCommand, async () => { await languageController.manageLanguages(); }); - this._apiWrapper.registerTaskHandler(constants.mlManageModelsCommand, async () => { - await modelManagementController.manageRegisteredModels(); - }); - this._apiWrapper.registerTaskHandler(constants.mlImportModelCommand, async () => { - await modelManagementController.registerModel(); - }); - this._apiWrapper.registerTaskHandler(constants.mlsPredictModelCommand, async () => { - await modelManagementController.predictModel(); - }); } /** diff --git a/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts index b3ac7161db..13c3542f9a 100644 --- a/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts +++ b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts @@ -12,6 +12,8 @@ import { QueryRunner } from '../common/queryRunner'; import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces'; import { ModelPythonClient } from './modelPythonClient'; import * as constants from '../common/constants'; +import { DatabaseTable } from '../prediction/interfaces'; +import { ModelConfigRecent } from './modelConfigRecent'; /** * Service to deployed models @@ -25,23 +27,25 @@ export class DeployedModelService { private _apiWrapper: ApiWrapper, private _config: Config, private _queryRunner: QueryRunner, - private _modelClient: ModelPythonClient) { + private _modelClient: ModelPythonClient, + private _recentModelService: ModelConfigRecent) { } /** * Returns deployed models */ - public async getDeployedModels(): Promise { + public async getDeployedModels(table: DatabaseTable): Promise { let connection = await this.getCurrentConnection(); let list: RegisteredModel[] = []; + if (!table.databaseName || !table.tableName || !table.schema) { + return []; + } if (connection) { - let query = this.getConfigureQuery(connection.databaseName); - await this._queryRunner.safeRunQuery(connection, query); - query = this.getDeployedModelsQuery(); + const query = this.getDeployedModelsQuery(table); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { result.rows.forEach(row => { - list.push(this.loadModelData(row)); + list.push(this.loadModelData(row, table)); }); } } else { @@ -82,10 +86,13 @@ export class DeployedModelService { * @param filePath model file path * @param details model details */ - public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined) { + public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined, table: DatabaseTable) { let connection = await this.getCurrentConnection(); - if (connection) { - let currentModels = await this.getDeployedModels(); + if (connection && table.databaseName) { + + await this.configureImport(connection, table); + + let currentModels = await this.getDeployedModels(table); const content = await utils.readFileInHex(filePath); const fileName = details?.fileName || utils.getFileName(filePath); let modelToAdd: RegisteredModel = { @@ -94,25 +101,92 @@ export class DeployedModelService { content: content, title: details?.title || fileName, description: details?.description, - version: details?.version + version: details?.version, + table: table }; - await this._queryRunner.safeRunQuery(connection, this.getInsertModelQuery(connection.databaseName, modelToAdd)); + await this._queryRunner.runWithDatabaseChange(connection, this.getInsertModelQuery(modelToAdd, table), table.databaseName); - let updatedModels = await this.getDeployedModels(); + let updatedModels = await this.getDeployedModels(table); if (updatedModels.length < currentModels.length + 1) { throw Error(constants.importModelFailedError(details?.title, filePath)); } + } else { + throw new Error(constants.noConnectionError); } } - private loadModelData(row: azdata.DbCellValue[]): RegisteredModel { + + public async configureImport(connection: azdata.connection.ConnectionProfile, table: DatabaseTable) { + if (connection && table.databaseName) { + let query = this.getDatabaseConfigureQuery(table); + await this._queryRunner.safeRunQuery(connection, query); + + query = this.getConfigureTableQuery(table); + await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName); + } + } + + /** + * Verifies if the given table name is valid to be used as import table. If table doesn't exist returns true to create new table + * Otherwise verifies the schema and returns true if the schema is supported + * @param connection database connection + * @param table config table name + */ + public async verifyConfigTable(table: DatabaseTable): Promise { + let connection = await this.getCurrentConnection(); + if (connection && table.databaseName) { + let databases = await this._apiWrapper.listDatabases(connection.connectionId); + + // If database exist verify the table schema + // + if ((await databases).find(x => x === table.databaseName)) { + const query = this.getConfigTableVerificationQuery(table); + const result = await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName); + return result !== undefined && result.rows.length > 0 && result.rows[0][0].displayValue === '1'; + } else { + return true; + } + } else { + throw new Error(constants.noConnectionError); + } + } + + public async getRecentImportTable(): Promise { + let connection = await this.getCurrentConnection(); + let table: DatabaseTable | undefined; + if (connection) { + table = this._recentModelService.getModelTable(connection); + if (!table) { + table = { + databaseName: connection.databaseName ?? 'master', + tableName: this._config.registeredModelTableName, + schema: this._config.registeredModelTableSchemaName + }; + } + } else { + throw new Error(constants.noConnectionError); + } + return table; + } + + public async storeRecentImportTable(importTable: DatabaseTable): Promise { + let connection = await this.getCurrentConnection(); + if (connection) { + this._recentModelService.storeModelTable(connection, importTable); + } else { + throw new Error(constants.noConnectionError); + } + } + + private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): RegisteredModel { return { id: +row[0].displayValue, artifactName: row[1].displayValue, title: row[2].displayValue, description: row[3].displayValue, version: row[4].displayValue, - created: row[5].displayValue + created: row[5].displayValue, + table: table }; } @@ -120,87 +194,138 @@ export class DeployedModelService { return await this._apiWrapper.getCurrentConnection(); } - public getConfigureQuery(currentDatabaseName: string): string { - return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.getConfigureTableQuery()); + public getDatabaseConfigureQuery(configTable: DatabaseTable): string { + return ` + IF NOT EXISTS ( + SELECT name + FROM sys.databases + WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}' + ) + CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}] + `; } - public getDeployedModelsQuery(): string { + public getDeployedModelsQuery(table: DatabaseTable): string { return ` SELECT artifact_id, artifact_name, name, description, version, created - FROM ${utils.getRegisteredModelsThreePartsName(this._config)} + FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')} WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml' Order by artifact_id `; } + /** + * Verifies config table has the expected schema + * @param databaseName + * @param tableName + */ + public getConfigTableVerificationQuery(table: DatabaseTable): string { + let tableName = table.tableName; + let schemaName = table.schema; + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); + + return ` + IF NOT EXISTS ( + SELECT name + FROM sys.databases + WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}' + ) + BEGIN + Select 1 + END + ELSE + BEGIN + USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}] + IF EXISTS + ( SELECT t.name, s.name + FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id + WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}' + AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}' + ) + BEGIN + IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version') + AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created') + BEGIN + Select 1 + END + ELSE + BEGIN + Select 0 + END + END + ELSE + select 1 + END + `; + } + /** * Update the table and adds extra columns (name, description, version) if doesn't already exist. * Note: this code is temporary and will be removed weh the table supports the required schema * @param databaseName * @param tableName */ - public getConfigureTableQuery(): string { - let databaseName = this._config.registeredModelDatabaseName; - let tableName = this._config.registeredModelTableName; - let schemaName = this._config.registeredModelTableSchemaName; + public getConfigureTableQuery(table: DatabaseTable): string { + let tableName = table.tableName; + let schemaName = table.schema; + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); return ` - IF NOT EXISTS ( - SELECT [name] - FROM sys.databases - WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}' - ) - CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}] - GO - USE [${utils.doubleEscapeSingleBrackets(databaseName)}] IF EXISTS - ( SELECT [t.name], [s.name] + ( SELECT t.name, s.name FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id - WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}' - AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}' + WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}' + AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}' ) BEGIN - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name') - ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version') - ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created') + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name') + ALTER TABLE ${twoPartTableName} ADD [artifact_name] [varchar](256) NOT NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content') + ALTER TABLE ${twoPartTableName} ADD [artifact_content] [varbinary](max) NOT NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name') + ALTER TABLE ${twoPartTableName} ADD [name] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version') + ALTER TABLE ${twoPartTableName} ADD [version] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created') BEGIN - ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL - ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created + ALTER TABLE ${twoPartTableName} ADD [created] [datetime] NULL + ALTER TABLE ${twoPartTableName} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created END - IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description') - ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='description') + ALTER TABLE ${twoPartTableName} ADD [description] [varchar](256) NULL END Else BEGIN - CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}( + CREATE TABLE ${twoPartTableName}( [artifact_id] [int] IDENTITY(1,1) NOT NULL, [artifact_name] [varchar](256) NOT NULL, - [group_path] [varchar](256) NULL, [artifact_content] [varbinary](max) NOT NULL, [artifact_initial_size] [bigint] NULL, [name] [varchar](256) NULL, [version] [varchar](256) NULL, [created] [datetime] NULL, [description] [varchar](256) NULL, - CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED + CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_artifact_pk] PRIMARY KEY CLUSTERED ( [artifact_id] ASC )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] - ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] + ALTER TABLE [dbo].[${utils.doubleEscapeSingleBrackets(tableName)}] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] END `; } - public getInsertModelQuery(currentDatabaseName: string, model: RegisteredModel): string { + public getInsertModelQuery(model: RegisteredModel, table: DatabaseTable): string { + const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || ''); + const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || ''); let updateScript = ` - Insert into ${utils.getRegisteredModelsTowPartsName(this._config)} - (artifact_name, group_path, artifact_content, name, version, description) + Insert into ${twoPartTableName} + (artifact_name, artifact_content, name, version, description) values ( '${utils.doubleEscapeSingleQuotes(model.artifactName || '')}', - 'ADS', ${utils.doubleEscapeSingleQuotes(model.content || '')}, '${utils.doubleEscapeSingleQuotes(model.title || '')}', '${utils.doubleEscapeSingleQuotes(model.version || '')}', @@ -208,17 +333,19 @@ export class DeployedModelService { `; return ` - ${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)} + ${updateScript} + SELECT artifact_id, artifact_name, name, description, version, created - FROM ${utils.getRegisteredModelsThreePartsName(this._config)} + FROM ${threePartTableName} WHERE artifact_id = SCOPE_IDENTITY(); `; } public getModelContentQuery(model: RegisteredModel): string { + const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || ''); return ` SELECT artifact_content - FROM ${utils.getRegisteredModelsThreePartsName(this._config)} + FROM ${threePartTableName} WHERE artifact_id = ${model.id}; `; } diff --git a/extensions/machine-learning-services/src/modelManagement/interfaces.ts b/extensions/machine-learning-services/src/modelManagement/interfaces.ts index 39a697fcfe..03f143aa2e 100644 --- a/extensions/machine-learning-services/src/modelManagement/interfaces.ts +++ b/extensions/machine-learning-services/src/modelManagement/interfaces.ts @@ -5,6 +5,7 @@ import * as msRest from '@azure/ms-rest-js'; import { Resource } from '@azure/arm-machinelearningservices/esm/models'; +import { DatabaseTable } from '../prediction/interfaces'; /** * An interface representing ListWorkspaceModelResult. @@ -52,6 +53,7 @@ export interface RegisteredModel extends RegisteredModelDetails { id: number; artifactName: string; content?: string; + table: DatabaseTable; } export interface ModelParameter { diff --git a/extensions/machine-learning-services/src/modelManagement/modelConfigRecent.ts b/extensions/machine-learning-services/src/modelManagement/modelConfigRecent.ts new file mode 100644 index 0000000000..760bda5e28 --- /dev/null +++ b/extensions/machine-learning-services/src/modelManagement/modelConfigRecent.ts @@ -0,0 +1,30 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as vscode from 'vscode'; +import * as azdata from 'azdata'; +import { DatabaseTable } from '../prediction/interfaces'; + +const TableConfigName = 'MLS_ModelTableConfigName'; + +export class ModelConfigRecent { + /** + * + */ + constructor(private _memento: vscode.Memento) { + } + + public getModelTable(connection: azdata.connection.ConnectionProfile): DatabaseTable | undefined { + return this._memento.get(this.getKey(connection)); + } + + public storeModelTable(connection: azdata.connection.ConnectionProfile, table: DatabaseTable): void { + this._memento.update(this.getKey(connection), table); + } + + private getKey(connection: azdata.connection.ConnectionProfile): string { + return `${TableConfigName}_${connection.serverName}`; + } +} diff --git a/extensions/machine-learning-services/src/prediction/predictService.ts b/extensions/machine-learning-services/src/prediction/predictService.ts index eebb89bd29..d60f28a663 100644 --- a/extensions/machine-learning-services/src/prediction/predictService.ts +++ b/extensions/machine-learning-services/src/prediction/predictService.ts @@ -10,7 +10,6 @@ import { QueryRunner } from '../common/queryRunner'; import * as utils from '../common/utils'; import { RegisteredModel } from '../modelManagement/interfaces'; import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces'; -import { Config } from '../configurations/config'; /** * Service to make prediction @@ -22,8 +21,7 @@ export class PredictService { */ constructor( private _apiWrapper: ApiWrapper, - private _queryRunner: QueryRunner, - private _config: Config) { + private _queryRunner: QueryRunner) { } /** @@ -54,7 +52,8 @@ export class PredictService { registeredModel.id, predictParams.inputColumns || [], predictParams.outputColumns || [], - predictParams); + predictParams, + registeredModel.table); } else if (filePath) { let modelBytes = await utils.readFileInHex(filePath || ''); query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [], @@ -142,18 +141,20 @@ WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuo modelId: number, columns: PredictColumn[], outputColumns: PredictColumn[], - databaseNameTable: DatabaseTable): string { + sourceTable: DatabaseTable, + importTable: DatabaseTable): string { + const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || ''); return ` DECLARE @model VARBINARY(max) = ( SELECT artifact_content - FROM ${utils.getRegisteredModelsThreePartsName(this._config)} + FROM ${threePartTableName} WHERE artifact_id = ${modelId} ); WITH predict_input AS ( SELECT TOP 1000 ${this.getInputColumnNames(columns, 'pi')} - FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi + FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi ) SELECT ${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')} diff --git a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts index 6d90c665df..3daf6e6850 100644 --- a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts +++ b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts @@ -17,6 +17,8 @@ import * as path from 'path'; import * as os from 'os'; import * as UUID from 'vscode-languageclient/lib/utils/uuid'; import * as fs from 'fs'; +import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent'; +import { DatabaseTable } from '../../prediction/interfaces'; interface TestContext { @@ -24,6 +26,8 @@ interface TestContext { config: TypeMoq.IMock; queryRunner: TypeMoq.IMock; modelClient: TypeMoq.IMock; + recentModels: TypeMoq.IMock; + importTable: DatabaseTable; } function createContext(): TestContext { @@ -32,7 +36,13 @@ function createContext(): TestContext { apiWrapper: TypeMoq.Mock.ofType(ApiWrapper), config: TypeMoq.Mock.ofType(Config), queryRunner: TypeMoq.Mock.ofType(QueryRunner), - modelClient: TypeMoq.Mock.ofType(ModelPythonClient) + modelClient: TypeMoq.Mock.ofType(ModelPythonClient), + recentModels: TypeMoq.Mock.ofType(ModelConfigRecent), + importTable: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } }; } @@ -40,14 +50,20 @@ describe('DeployedModelService', () => { it('getDeployedModels should fail with no connection', async function (): Promise { const testContext = createContext(); let connection: azdata.connection.ConnectionProfile; + let importTable: DatabaseTable = { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + }; testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); let service = new DeployedModelService( testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); - await should(service.getDeployedModels()).rejected(); + testContext.modelClient.object, + testContext.recentModels.object); + await should(service.getDeployedModels(importTable)).rejected(); }); it('getDeployedModels should returns models successfully', async function (): Promise { @@ -61,7 +77,9 @@ describe('DeployedModelService', () => { title: 'title1', description: 'desc1', created: '2018-01-01', - version: '1.1' + version: '1.1', + table: testContext.importTable + } ]; const result = { @@ -106,12 +124,13 @@ describe('DeployedModelService', () => { testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); + testContext.modelClient.object, + testContext.recentModels.object); testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table'); - const actual = await service.getDeployedModels(); + const actual = await service.getDeployedModels(testContext.importTable); should.deepEqual(actual, expected); }); @@ -140,7 +159,8 @@ describe('DeployedModelService', () => { testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); + testContext.modelClient.object, + testContext.recentModels.object); const actual = await service.loadModelParameters(''); should.deepEqual(actual, expected); }); @@ -158,7 +178,8 @@ describe('DeployedModelService', () => { title: 'title1', description: 'desc1', created: '2018-01-01', - version: '1.1' + version: '1.1', + table: testContext.importTable }; const result = { rowCount: 1, @@ -177,7 +198,8 @@ describe('DeployedModelService', () => { testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); + testContext.modelClient.object, + testContext.recentModels.object); testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); @@ -198,7 +220,8 @@ describe('DeployedModelService', () => { title: 'title1', description: 'desc1', created: '2018-01-01', - version: '1.1' + version: '1.1', + table: testContext.importTable }; const row = [ { @@ -247,15 +270,17 @@ describe('DeployedModelService', () => { testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); + testContext.modelClient.object, + testContext.recentModels.object); - testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => { + testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0), TypeMoq.It.isAny())).returns(() => { deployed = true; return Promise.resolve(result); }); testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result); }); + testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table'); @@ -264,7 +289,7 @@ describe('DeployedModelService', () => { try { tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`); await fs.promises.writeFile(tempFilePath, 'test'); - await should(service.deployLocalModel(tempFilePath, model)).resolved(); + await should(service.deployLocalModel(tempFilePath, model, testContext.importTable)).resolved(); } finally { await utils.deleteFile(tempFilePath); @@ -273,31 +298,28 @@ describe('DeployedModelService', () => { it('getConfigureQuery should escape db name', async function (): Promise { const testContext = createContext(); - const dbName = 'curre[n]tDb'; let service = new DeployedModelService( testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); - testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); - testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); - testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + testContext.modelClient.object, + testContext.recentModels.object); + + testContext.importTable.databaseName = 'd[]b'; + testContext.importTable.tableName = 'ta[b]le'; + testContext.importTable.schema = 'dbo'; const expected = ` - IF NOT EXISTS ( - SELECT [name] - FROM sys.databases - WHERE [name] = N'd[]b' - ) - CREATE DATABASE [d[[]]b] - GO - USE [d[[]]b] IF EXISTS - ( SELECT [t.name], [s.name] + ( SELECT t.name, s.name FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id - WHERE [t.name] = 'ta[b]le' - AND [s.name] = 'dbo' + WHERE t.name = 'ta[b]le' + AND s.name = 'dbo' ) BEGIN + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_name') + ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_name] [varchar](256) NOT NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_content') + ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_content] [varbinary](max) NOT NULL IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name') ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version') @@ -315,23 +337,22 @@ describe('DeployedModelService', () => { CREATE TABLE [dbo].[ta[[b]]le]( [artifact_id] [int] IDENTITY(1,1) NOT NULL, [artifact_name] [varchar](256) NOT NULL, - [group_path] [varchar](256) NULL, [artifact_content] [varbinary](max) NOT NULL, [artifact_initial_size] [bigint] NULL, [name] [varchar](256) NULL, [version] [varchar](256) NULL, [created] [datetime] NULL, [description] [varchar](256) NULL, - CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED + CONSTRAINT [ta[[b]]le_artifact_pk] PRIMARY KEY CLUSTERED ( [artifact_id] ASC )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] - ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] + ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] END `; - const actual = service.getConfigureQuery(dbName); - should.equal(actual.indexOf(expected) > 0, true); + const actual = service.getConfigureTableQuery(testContext.importTable); + should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`); }); it('getDeployedModelsQuery should escape db name', async function (): Promise { @@ -340,23 +361,23 @@ describe('DeployedModelService', () => { testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); - testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); - testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); - testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + testContext.modelClient.object, + testContext.recentModels.object); + testContext.importTable.databaseName = 'd[]b'; + testContext.importTable.tableName = 'ta[b]le'; + testContext.importTable.schema = 'dbo'; const expected = ` SELECT artifact_id, artifact_name, name, description, version, created FROM [d[[]]b].[dbo].[ta[[b]]le] WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml' Order by artifact_id `; - const actual = service.getDeployedModelsQuery(); + const actual = service.getDeployedModelsQuery(testContext.importTable); should.deepEqual(expected, actual); }); it('getInsertModelQuery should escape db name', async function (): Promise { const testContext = createContext(); - const dbName = 'curre[n]tDb'; const model: RegisteredModel = { id: 1, @@ -364,28 +385,27 @@ describe('DeployedModelService', () => { title: 'title1', description: 'desc1', created: '2018-01-01', - version: '1.1' + version: '1.1', + table: testContext.importTable }; let service = new DeployedModelService( testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); - testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); - testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); - testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + testContext.modelClient.object, + testContext.recentModels.object); + const expected = ` - Insert into [dbo].[ta[[b]]le] - (artifact_name, group_path, artifact_content, name, version, description) + Insert into [dbo].[tb] + (artifact_name, artifact_content, name, version, description) values ( 'name1', - 'ADS', , 'title1', '1.1', 'desc1')`; - const actual = service.getInsertModelQuery(dbName, model); + const actual = service.getInsertModelQuery(model, testContext.importTable); should.equal(actual.indexOf(expected) > 0, true); }); @@ -398,17 +418,19 @@ describe('DeployedModelService', () => { title: 'title1', description: 'desc1', created: '2018-01-01', - version: '1.1' + version: '1.1', + table: testContext.importTable }; let service = new DeployedModelService( testContext.apiWrapper.object, testContext.config.object, testContext.queryRunner.object, - testContext.modelClient.object); - testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); - testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); - testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + testContext.modelClient.object, + testContext.recentModels.object); + model.table = { + databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo' + }; const expected = ` SELECT artifact_content FROM [d[[]]b].[dbo].[ta[[b]]le] diff --git a/extensions/machine-learning-services/src/test/prediction/predictService.test.ts b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts index 94b3a8cd84..442a2e6fa0 100644 --- a/extensions/machine-learning-services/src/test/prediction/predictService.test.ts +++ b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts @@ -8,7 +8,6 @@ import * as vscode from 'vscode'; import { ApiWrapper } from '../../common/apiWrapper'; import * as TypeMoq from 'typemoq'; import * as should from 'should'; -import { Config } from '../../configurations/config'; import { PredictService } from '../../prediction/predictService'; import { QueryRunner } from '../../common/queryRunner'; import { RegisteredModel } from '../../modelManagement/interfaces'; @@ -22,7 +21,7 @@ import * as fs from 'fs'; interface TestContext { apiWrapper: TypeMoq.IMock; - config: TypeMoq.IMock; + importTable: DatabaseTable; queryRunner: TypeMoq.IMock; } @@ -30,7 +29,11 @@ function createContext(): TestContext { return { apiWrapper: TypeMoq.Mock.ofType(ApiWrapper), - config: TypeMoq.Mock.ofType(Config), + importTable: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + }, queryRunner: TypeMoq.Mock.ofType(QueryRunner) }; } @@ -49,8 +52,7 @@ describe('PredictService', () => { let service = new PredictService( testContext.apiWrapper.object, - testContext.queryRunner.object, - testContext.config.object); + testContext.queryRunner.object); const actual = await service.getDatabaseList(); should.deepEqual(actual, expected); }); @@ -102,8 +104,7 @@ describe('PredictService', () => { testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); let service = new PredictService( testContext.apiWrapper.object, - testContext.queryRunner.object, - testContext.config.object); + testContext.queryRunner.object); const actual = await service.getTableList('db1'); should.deepEqual(actual, expected); }); @@ -160,8 +161,7 @@ describe('PredictService', () => { testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); let service = new PredictService( testContext.apiWrapper.object, - testContext.queryRunner.object, - testContext.config.object); + testContext.queryRunner.object); const actual = await service.getTableColumnsList(table); should.deepEqual(actual, expected); }); @@ -201,13 +201,13 @@ describe('PredictService', () => { title: 'title1', description: 'desc1', created: '2018-01-01', - version: '1.1' + version: '1.1', + table: testContext.importTable }; let service = new PredictService( testContext.apiWrapper.object, - testContext.queryRunner.object, - testContext.config.object); + testContext.queryRunner.object); const document: vscode.TextDocument = { uri: vscode.Uri.parse('file:///usr/home'), @@ -270,8 +270,7 @@ describe('PredictService', () => { let service = new PredictService( testContext.apiWrapper.object, - testContext.queryRunner.object, - testContext.config.object); + testContext.queryRunner.object); const document: vscode.TextDocument = { uri: vscode.Uri.parse('file:///usr/home'), diff --git a/extensions/machine-learning-services/src/test/views/dashboardWidget.test.ts b/extensions/machine-learning-services/src/test/views/dashboardWidget.test.ts index de1fe94b28..2aa6305d58 100644 --- a/extensions/machine-learning-services/src/test/views/dashboardWidget.test.ts +++ b/extensions/machine-learning-services/src/test/views/dashboardWidget.test.ts @@ -34,6 +34,6 @@ describe('Dashboard widget', () => { const dashboard = new DashboardWidget(testContext.apiWrapper.object, ''); dashboard.register(); testContext.onClick.fire(); - testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atMostOnce()); + testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce()); }); }); diff --git a/extensions/machine-learning-services/src/test/views/models/ModelManagementController.test.ts b/extensions/machine-learning-services/src/test/views/models/ModelManagementController.test.ts new file mode 100644 index 0000000000..3202b9cb3a --- /dev/null +++ b/extensions/machine-learning-services/src/test/views/models/ModelManagementController.test.ts @@ -0,0 +1,170 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as should from 'should'; +import * as TypeMoq from 'typemoq'; +import 'mocha'; +import { createContext } from './utils'; +import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces'; +import { azureResource } from '../../../typings/azure-resource'; +import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; +import { WorkspaceModel } from '../../../modelManagement/interfaces'; +import { ModelManagementController } from '../../../views/models/modelManagementController'; +import { DatabaseTable, TableColumn } from '../../../prediction/interfaces'; + +const accounts: azdata.Account[] = [ + { + key: { + accountId: '1', + providerId: '' + }, + displayInfo: { + displayName: 'account', + userId: '', + accountType: '', + contextualDisplayName: '' + }, + isStale: false, + properties: [] + } +]; +const subscriptions: azureResource.AzureResourceSubscription[] = [ + { + name: 'subscription', + id: '2' + } +]; +const groups: azureResource.AzureResourceResourceGroup[] = [ + { + name: 'group', + id: '3' + } +]; +const workspaces: Workspace[] = [ + { + name: 'workspace', + id: '4' + } +]; +const models: WorkspaceModel[] = [ + { + id: '5', + name: 'model' + } +]; +const localModels: RegisteredModel[] = [ + { + id: 1, + artifactName: 'model', + title: 'model', + table: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } + } +]; + +const dbNames: string[] = [ + 'db1', + 'db2' +]; +const tableNames: DatabaseTable[] = [ + { + databaseName: 'db1', + schema: 'dbo', + tableName: 'tb1' + }, + { + databaseName: 'db1', + tableName: 'tb2', + schema: 'dbo' + } +]; +const columnNames: TableColumn[] = [ + { + columnName: 'c1', + dataType: 'int' + }, + { + columnName: 'c2', + dataType: 'varchar' + } +]; +const modelParameters: ModelParameters = { + inputs: [ + { + 'name': 'p1', + 'type': 'int' + }, + { + 'name': 'p2', + 'type': 'varchar' + } + ], + outputs: [ + { + 'name': 'o1', + 'type': 'int' + } + ] +}; +describe('Model Controller', () => { + + it('Should open deploy model wizard successfully ', async function (): Promise { + let testContext = createContext(); + + + let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object); + testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({ + databaseName: 'db', + tableName: 'table', + schema: 'dbo' + })); + testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels)); + testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames)); + testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames)); + testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts)); + testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions)); + testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups)); + testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces)); + testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models)); + + const view = await controller.registerModel(undefined); + should.notEqual(view, undefined); + }); + + it('Should open predict wizard successfully ', async function (): Promise { + let testContext = createContext(); + + + let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object); + testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({ + databaseName: 'db', + tableName: 'table', + schema: 'dbo' + })); + testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels)); + testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve([ + 'db', 'db1' + ])); + testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve([ + { tableName: 'tb', databaseName: 'db', schema: 'dbo' } + ])); + testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts)); + testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions)); + testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups)); + testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces)); + testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models)); + testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames)); + testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters)); + testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file')); + testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file')); + + const view = await controller.predictModel(); + should.notEqual(view, undefined); + }); +}); diff --git a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts index 77d4b8dc64..817885971c 100644 --- a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts @@ -34,6 +34,11 @@ describe('Predict Wizard', () => { let testContext = createContext(); let view = new PredictWizard(testContext.apiWrapper.object, ''); + view.importTable = { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + }; await view.open(); let accounts: azdata.Account[] = [ { @@ -79,7 +84,12 @@ describe('Predict Wizard', () => { { id: 1, artifactName: 'model', - title: 'model' + title: 'model', + table: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } } ]; const dbNames: string[] = [ diff --git a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts index 0bf7a7b9de..e8542183af 100644 --- a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts @@ -7,21 +7,25 @@ import * as azdata from 'azdata'; import * as should from 'should'; import 'mocha'; import { createContext } from './utils'; -import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType } from '../../../views/models/modelViewBase'; +import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase'; import { RegisteredModel } from '../../../modelManagement/interfaces'; import { azureResource } from '../../../typings/azure-resource'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { ViewBase } from '../../../views/viewBase'; import { WorkspaceModel } from '../../../modelManagement/interfaces'; -import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard'; +import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard'; describe('Register Model Wizard', () => { it('Should create view components successfully ', async function (): Promise { let testContext = createContext(); - let view = new RegisterModelWizard(testContext.apiWrapper.object, ''); + let view = new ImportModelWizard(testContext.apiWrapper.object, ''); + view.importTable = { + databaseName: 'db', + tableName: 'table', + schema: 'dbo' + }; await view.open(); - await view.refresh(); should.notEqual(view.wizardView, undefined); should.notEqual(view.modelSourcePage, undefined); }); @@ -29,7 +33,12 @@ describe('Register Model Wizard', () => { it('Should load data successfully ', async function (): Promise { let testContext = createContext(); - let view = new RegisterModelWizard(testContext.apiWrapper.object, ''); + let view = new ImportModelWizard(testContext.apiWrapper.object, ''); + view.importTable = { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + }; await view.open(); let accounts: azdata.Account[] = [ { @@ -75,12 +84,27 @@ describe('Register Model Wizard', () => { { id: 1, artifactName: 'model', - title: 'model' + title: 'model', + table: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } } ]; view.on(ListModelsEventName, () => { view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels }); }); + view.on(ListDatabaseNamesEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [ + 'db', 'db1' + ] }); + }); + view.on(ListTableNamesEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [ + 'tb', 'tb1' + ] }); + }); view.on(ListAccountsEventName, () => { view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts }); }); diff --git a/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts b/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts index 9c4d84c33f..0a40b61de6 100644 --- a/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registeredModelsDialog.test.ts @@ -6,7 +6,7 @@ import * as should from 'should'; import 'mocha'; import { createContext } from './utils'; -import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog'; +import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog'; import { ListModelsEventName } from '../../../views/models/modelViewBase'; import { RegisteredModel } from '../../../modelManagement/interfaces'; import { ViewBase } from '../../../views/viewBase'; @@ -15,7 +15,7 @@ describe('Registered Models Dialog', () => { it('Should create view components successfully ', async function (): Promise { let testContext = createContext(); - let view = new RegisteredModelsDialog(testContext.apiWrapper.object, ''); + let view = new ManageModelsDialog(testContext.apiWrapper.object, ''); view.open(); should.notEqual(view.dialogView, undefined); @@ -25,13 +25,18 @@ describe('Registered Models Dialog', () => { it('Should load data successfully ', async function (): Promise { let testContext = createContext(); - let view = new RegisteredModelsDialog(testContext.apiWrapper.object, ''); + let view = new ManageModelsDialog(testContext.apiWrapper.object, ''); view.open(); let models: RegisteredModel[] = [ { id: 1, artifactName: 'model', - title: '' + title: '', + table: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } } ]; view.on(ListModelsEventName, () => { diff --git a/extensions/machine-learning-services/src/test/views/models/utils.ts b/extensions/machine-learning-services/src/test/views/models/utils.ts index c65150a998..986f343f18 100644 --- a/extensions/machine-learning-services/src/test/views/models/utils.ts +++ b/extensions/machine-learning-services/src/test/views/models/utils.ts @@ -9,11 +9,17 @@ import * as TypeMoq from 'typemoq'; import { ApiWrapper } from '../../../common/apiWrapper'; import { createViewContext } from '../utils'; import { ModelViewBase } from '../../../views/models/modelViewBase'; +import { AzureModelRegistryService } from '../../../modelManagement/azureModelRegistryService'; +import { DeployedModelService } from '../../../modelManagement/deployedModelService'; +import { PredictService } from '../../../prediction/predictService'; export interface TestContext { apiWrapper: TypeMoq.IMock; view: azdata.ModelView; onClick: vscode.EventEmitter; + azureModelService: TypeMoq.IMock; + deployModelService: TypeMoq.IMock; + predictService: TypeMoq.IMock; } export class ParentDialog extends ModelViewBase { @@ -36,6 +42,9 @@ export function createContext(): TestContext { return { apiWrapper: viewTestContext.apiWrapper, view: viewTestContext.view, - onClick: viewTestContext.onClick + onClick: viewTestContext.onClick, + azureModelService: TypeMoq.Mock.ofType(AzureModelRegistryService), + deployModelService: TypeMoq.Mock.ofType(DeployedModelService), + predictService: TypeMoq.Mock.ofType(PredictService) }; } diff --git a/extensions/machine-learning-services/src/test/views/utils.ts b/extensions/machine-learning-services/src/test/views/utils.ts index ede30eca63..b3b4ece492 100644 --- a/extensions/machine-learning-services/src/test/views/utils.ts +++ b/extensions/machine-learning-services/src/test/views/utils.ts @@ -31,6 +31,11 @@ export function createViewContext(): ViewTestContext { let button: azdata.ButtonComponent = Object.assign({}, componentBase, { onDidClick: onClick.event }); + let link: azdata.HyperlinkComponent = Object.assign({}, componentBase, { + onDidClick: onClick.event, + label: '', + url: '' + }); let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, { checked: true, onDidClick: onClick.event @@ -61,6 +66,11 @@ export function createViewContext(): ViewTestContext { withProperties: () => buttonBuilder, withValidation: () => buttonBuilder }; + let hyperLinkBuilder: azdata.ComponentBuilder = { + component: () => link, + withProperties: () => hyperLinkBuilder, + withValidation: () => hyperLinkBuilder + }; let radioButtonBuilder: azdata.ComponentBuilder = { component: () => radioButton, withProperties: () => radioButtonBuilder, @@ -72,7 +82,7 @@ export function createViewContext(): ViewTestContext { withValidation: () => checkBoxBuilder }; let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, { - onTextChanged: undefined!, + onTextChanged: onClick.event!, onEnterKeyPressed: undefined!, value: '' }); @@ -216,7 +226,7 @@ export function createViewContext(): ViewTestContext { toolbarContainer: undefined!, loadingComponent: () => loadingBuilder, fileBrowserTree: undefined!, - hyperlink: undefined!, + hyperlink: () => hyperLinkBuilder, tabbedPanel: undefined!, separator: undefined! } diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsComponent.ts similarity index 56% rename from extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts rename to extensions/machine-learning-services/src/views/models/manageModels/currentModelsComponent.ts index c05aa078d1..7dc339d7da 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsComponent.ts @@ -10,21 +10,24 @@ import { ModelViewBase } from '../modelViewBase'; import { CurrentModelsTable } from './currentModelsTable'; import { ApiWrapper } from '../../../common/apiWrapper'; import { IPageView } from '../../interfaces'; +import { TableSelectionComponent } from '../tableSelectionComponent'; +import { RegisteredModel } from '../../../modelManagement/interfaces'; /** * View to render current registered models */ -export class CurrentModelsPage extends ModelViewBase implements IPageView { +export class CurrentModelsComponent extends ModelViewBase implements IPageView { private _tableComponent: azdata.Component | undefined; private _dataTable: CurrentModelsTable | undefined; private _loader: azdata.LoadingComponent | undefined; + private _tableSelectionComponent: TableSelectionComponent | undefined; /** * * @param apiWrapper Creates new view * @param parent page parent */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = false) { super(apiWrapper, parent.root, parent); } @@ -33,11 +36,17 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView { * @param modelBuilder register the components */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._dataTable = new CurrentModelsTable(this._apiWrapper, this, false); + this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false); + this._tableSelectionComponent.registerComponent(modelBuilder); + this._tableSelectionComponent.onSelectedChanged(async () => { + await this.onTableSelected(); + }); + this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect); this._dataTable.registerComponent(modelBuilder); this._tableComponent = this._dataTable.component; let formModelBuilder = modelBuilder.formContainer(); + this._tableSelectionComponent.addComponents(formModelBuilder); if (this._tableComponent) { formModelBuilder.addFormItem({ @@ -54,6 +63,20 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView { return this._loader; } + public addComponents(formBuilder: azdata.FormBuilder) { + if (this._tableSelectionComponent && this._dataTable) { + this._tableSelectionComponent.addComponents(formBuilder); + this._dataTable.addComponents(formBuilder); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + if (this._tableSelectionComponent && this._dataTable) { + this._tableSelectionComponent.removeComponents(formBuilder); + this._dataTable.removeComponents(formBuilder); + } + } + /** * Returns the component */ @@ -68,6 +91,9 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView { await this.onLoading(); try { + if (this._tableSelectionComponent) { + this._tableSelectionComponent.refresh(); + } await this._dataTable?.refresh(); } catch (err) { this.showErrorMessage(constants.getErrorMessage(err)); @@ -76,6 +102,31 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView { } } + public get data(): RegisteredModel[] | undefined { + return this._dataTable?.data; + } + + private async onTableSelected(): Promise { + if (this._tableSelectionComponent?.data) { + this.importTable = this._tableSelectionComponent?.data; + await this.storeImportConfigTable(); + await this._dataTable?.refresh(); + } + } + + public get modelTable(): CurrentModelsTable | undefined { + return this._dataTable; + } + + /** + * disposes the view + */ + public async disposeComponent(): Promise { + if (this._dataTable) { + await this._dataTable.disposeComponent(); + } + } + /** * returns the title of the page */ diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsTable.ts similarity index 97% rename from extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts rename to extensions/machine-learning-services/src/views/models/manageModels/currentModelsTable.ts index 1da086f8f3..73c3a07de2 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/currentModelsTable.ts @@ -134,7 +134,11 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< if (this._table) { let models: RegisteredModel[] | undefined; - models = await this.listModels(); + if (this.importTable) { + models = await this.listModels(this.importTable); + } else { + this.showErrorMessage('No import table'); + } let tableData: any[][] = []; if (models) { diff --git a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts b/extensions/machine-learning-services/src/views/models/manageModels/importModelWizard.ts similarity index 85% rename from extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts rename to extensions/machine-learning-services/src/views/models/manageModels/importModelWizard.ts index d965ee6c19..709fcd89bf 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/importModelWizard.ts @@ -14,15 +14,17 @@ import { WizardView } from '../../wizardView'; import { ModelSourcePage } from '../modelSourcePage'; import { ModelDetailsPage } from '../modelDetailsPage'; import { ModelBrowsePage } from '../modelBrowsePage'; +import { ModelImportLocationPage } from './modelmportLocationPage'; /** * Wizard to register a model */ -export class RegisterModelWizard extends ModelViewBase { +export class ImportModelWizard extends ModelViewBase { public modelSourcePage: ModelSourcePage | undefined; public modelBrowsePage: ModelBrowsePage | undefined; public modelDetailsPage: ModelDetailsPage | undefined; + public modelImportTargetPage: ModelImportLocationPage | undefined; public wizardView: WizardView | undefined; private _parentView: ModelViewBase | undefined; @@ -41,9 +43,10 @@ export class RegisterModelWizard extends ModelViewBase { this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this); this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this); this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this); + this.modelImportTargetPage = new ModelImportLocationPage(this._apiWrapper, this); this.wizardView = new WizardView(this._apiWrapper); - let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]); + let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelImportTargetPage, this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]); this.mainViewPanel = wizard; wizard.doneButton.label = constants.azureRegisterModel; @@ -61,7 +64,8 @@ export class RegisterModelWizard extends ModelViewBase { wizard.cancelButton.enabled = true; wizard.backButton.enabled = true; if (this._parentView) { - await this._parentView?.refresh(); + this._parentView.importTable = this.importTable; + await this._parentView.refresh(); } return result; @@ -87,10 +91,11 @@ export class RegisterModelWizard extends ModelViewBase { private async registerModel(): Promise { try { if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) { - await this.registerLocalModel(this.modelsViewData); + await this.importLocalModel(this.modelsViewData); } else { - await this.registerAzureModel(this.modelsViewData); + await this.importAzureModel(this.modelsViewData); } + await this.storeImportConfigTable(); this.showInfoMessage(constants.modelRegisteredSuccessfully); return true; } catch (error) { diff --git a/extensions/machine-learning-services/src/views/models/registerModels/registeredModelsDialog.ts b/extensions/machine-learning-services/src/views/models/manageModels/manageModelsDialog.ts similarity index 81% rename from extensions/machine-learning-services/src/views/models/registerModels/registeredModelsDialog.ts rename to extensions/machine-learning-services/src/views/models/manageModels/manageModelsDialog.ts index af45cec44b..2dee324da2 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/registeredModelsDialog.ts +++ b/extensions/machine-learning-services/src/views/models/manageModels/manageModelsDialog.ts @@ -3,7 +3,7 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import { CurrentModelsPage } from './currentModelsPage'; +import { CurrentModelsComponent } from './currentModelsComponent'; import { ModelViewBase, RegisterModelEventName } from '../modelViewBase'; import * as constants from '../../../common/constants'; @@ -13,7 +13,7 @@ import { DialogView } from '../../dialogView'; /** * Dialog to render registered model views */ -export class RegisteredModelsDialog extends ModelViewBase { +export class ManageModelsDialog extends ModelViewBase { constructor( apiWrapper: ApiWrapper, @@ -22,18 +22,18 @@ export class RegisteredModelsDialog extends ModelViewBase { this.dialogView = new DialogView(this._apiWrapper); } public dialogView: DialogView; - public currentLanguagesTab: CurrentModelsPage | undefined; + public currentLanguagesTab: CurrentModelsComponent | undefined; /** * Opens a dialog to manage packages used by notebooks. */ public open(): void { - this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this); + this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this); let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle); registerModelButton.onClick(async () => { - await this.sendDataRequest(RegisterModelEventName); + await this.sendDataRequest(RegisterModelEventName, this.currentLanguagesTab?.modelTable?.importTable); }); let dialog = this.dialogView.createDialog(constants.registerModelTitle, [this.currentLanguagesTab]); diff --git a/extensions/machine-learning-services/src/views/models/manageModels/modelmportLocationPage.ts b/extensions/machine-learning-services/src/views/models/manageModels/modelmportLocationPage.ts new file mode 100644 index 0000000000..8f8098b360 --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/manageModels/modelmportLocationPage.ts @@ -0,0 +1,98 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import { ModelViewBase } from '../modelViewBase'; +import { ApiWrapper } from '../../../common/apiWrapper'; +import * as constants from '../../../common/constants'; +import { IPageView, IDataComponent } from '../../interfaces'; +import { TableSelectionComponent } from '../tableSelectionComponent'; +import { DatabaseTable } from '../../../prediction/interfaces'; + +/** + * View to pick model source + */ +export class ModelImportLocationPage extends ModelViewBase implements IPageView, IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _formBuilder: azdata.FormBuilder | undefined; + public tableSelectionComponent: TableSelectionComponent | undefined; + + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { + super(apiWrapper, parent.root, parent); + } + + /** + * + * @param modelBuilder Register components + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + + this._formBuilder = modelBuilder.formContainer(); + this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, true); + this.tableSelectionComponent.onSelectedChanged(async () => { + await this.onTableSelected(); + }); + this.tableSelectionComponent.registerComponent(modelBuilder); + this.tableSelectionComponent.addComponents(this._formBuilder); + this._form = this._formBuilder.component(); + return this._form; + } + + private async onTableSelected(): Promise { + if (this.tableSelectionComponent?.data) { + this.importTable = this.tableSelectionComponent?.data; + //this.sendRequest(StoreImportTableEventName, this.importTable); + } + } + + /** + * Returns selected data + */ + public get data(): DatabaseTable | undefined { + return this.tableSelectionComponent?.data; + } + + /** + * Returns the component + */ + public get component(): azdata.Component | undefined { + return this._form; + } + + /** + * Refreshes the view + */ + public async refresh(): Promise { + if (this.tableSelectionComponent) { + await this.tableSelectionComponent.refresh(); + } + } + + /** + * Returns page title + */ + public get title(): string { + return constants.modelImportTargetPageTitle; + } + + public async disposePage(): Promise { + } + + public async validate(): Promise { + let validated = false; + + if (this.data?.databaseName && this.data?.tableName) { + validated = true; + validated = await this.verifyImportConfigTable(this.data); + if (!validated) { + this.showErrorMessage(constants.invalidImportTableSchemaError(this.data?.databaseName, this.data?.tableName)); + } + } else { + this.showErrorMessage(constants.invalidImportTableError(this.data?.databaseName, this.data?.tableName)); + } + return validated; + } +} diff --git a/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts b/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts index 3bd3b6fc3a..146f1b1281 100644 --- a/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts +++ b/extensions/machine-learning-services/src/views/models/modelBrowsePage.ts @@ -10,8 +10,8 @@ import * as constants from '../../common/constants'; import { IPageView, IDataComponent } from '../interfaces'; import { LocalModelsComponent } from './localModelsComponent'; import { AzureModelsComponent } from './azureModelsComponent'; -import { CurrentModelsTable } from './registerModels/currentModelsTable'; import * as utils from '../../common/utils'; +import { CurrentModelsComponent } from './manageModels/currentModelsComponent'; /** * View to pick model source @@ -19,10 +19,11 @@ import * as utils from '../../common/utils'; export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent { private _form: azdata.FormContainer | undefined; + private _title: string = constants.modelSourcePageTitle; private _formBuilder: azdata.FormBuilder | undefined; public localModelsComponent: LocalModelsComponent | undefined; public azureModelsComponent: AzureModelsComponent | undefined; - public registeredModelsComponent: CurrentModelsTable | undefined; + public registeredModelsComponent: CurrentModelsComponent | undefined; constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) { super(apiWrapper, parent.root, parent); @@ -39,7 +40,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo this.localModelsComponent.registerComponent(modelBuilder); this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect); this.azureModelsComponent.registerComponent(modelBuilder); - this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect); + this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, this._multiSelect); this.registeredModelsComponent.registerComponent(modelBuilder); this.refresh(); this._form = this._formBuilder.component(); @@ -88,16 +89,29 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo this.registeredModelsComponent.addComponents(this._formBuilder); await this.registeredModelsComponent.refresh(); } - } } + this.loadTitle(); + } + + private loadTitle(): void { + if (this.modelSourceType === ModelSourceType.Local) { + this._title = 'Upload model file'; + } else if (this.modelSourceType === ModelSourceType.Azure) { + this._title = 'Import from Azure Machine Learning'; + + } else if (this.modelSourceType === ModelSourceType.RegisteredModels) { + this._title = 'Select imported model'; + } else { + this._title = constants.modelSourcePageTitle; + } } /** * Returns page title */ public get title(): string { - return constants.modelSourcePageTitle; + return this._title; } public validate(): Promise { @@ -117,6 +131,10 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo return Promise.resolve(validated); } + public onEnter(): Promise { + return Promise.resolve(); + } + public async onLeave(): Promise { this.modelsViewData = []; if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) { @@ -128,7 +146,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo modelDetails: { title: fileName, fileName: fileName - } + }, + targetImportTable: this.importTable }; }); } @@ -147,7 +166,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo modelDetails: { title: x.model?.name || '', fileName: x.model?.name - } + }, + targetImportTable: this.importTable }; }); } @@ -159,7 +179,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo modelData: x, modelDetails: { title: '' - } + }, + targetImportTable: this.importTable }; }); } diff --git a/extensions/machine-learning-services/src/views/models/modelManagementController.ts b/extensions/machine-learning-services/src/views/models/modelManagementController.ts index b31884008a..b6b60b5fd4 100644 --- a/extensions/machine-learning-services/src/views/models/modelManagementController.ts +++ b/extensions/machine-learning-services/src/views/models/modelManagementController.ts @@ -12,15 +12,15 @@ import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { RegisteredModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces'; import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { DeployedModelService } from '../../modelManagement/deployedModelService'; -import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog'; +import { ManageModelsDialog } from './manageModels/manageModelsDialog'; import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, - ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData + ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName } from './modelViewBase'; import { ControllerBase } from '../controllerBase'; -import { RegisterModelWizard } from './registerModels/registerModelWizard'; +import { ImportModelWizard } from './manageModels/importModelWizard'; import * as fs from 'fs'; import * as constants from '../../common/constants'; import { PredictWizard } from './prediction/predictWizard'; @@ -51,11 +51,16 @@ export class ModelManagementController extends ControllerBase { * @param apiWrapper apiWrapper * @param root root folder path */ - public async registerModel(parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise { + public async registerModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise { controller = controller || this; apiWrapper = apiWrapper || this._apiWrapper; root = root || this._root; - let view = new RegisterModelWizard(apiWrapper, root, parent); + let view = new ImportModelWizard(apiWrapper, root, parent); + if (importTable) { + view.importTable = importTable; + } else { + view.importTable = await controller._registeredModelService.getRecentImportTable(); + } controller.registerEvents(view); @@ -72,6 +77,7 @@ export class ModelManagementController extends ControllerBase { public async predictModel(): Promise { let view = new PredictWizard(this._apiWrapper, this._root); + view.importTable = await this._registeredModelService.getRecentImportTable(); this.registerEvents(view); view.on(LoadModelParametersEventName, async () => { @@ -117,17 +123,18 @@ export class ModelManagementController extends ControllerBase { await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService , azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace); }); - - view.on(ListModelsEventName, async () => { - await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService); + view.on(ListModelsEventName, async (args) => { + const table = args; + await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService, table); }); view.on(RegisterLocalModelEventName, async (arg) => { let models = arg; await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models); view.refresh(); }); - view.on(RegisterModelEventName, async () => { - await this.executeAction(view, RegisterModelEventName, this.registerModel, view, this, this._apiWrapper, this._root); + view.on(RegisterModelEventName, async (args) => { + const importTable = args; + await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root); }); view.on(RegisterAzureModelEventName, async (arg) => { let models = arg; @@ -161,6 +168,16 @@ export class ModelManagementController extends ControllerBase { await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService, model); }); + view.on(StoreImportTableEventName, async (arg) => { + let importTable = arg; + await this.executeAction(view, StoreImportTableEventName, this.storeImportTable, this._registeredModelService, + importTable); + }); + view.on(VerifyImportTableEventName, async (arg) => { + let importTable = arg; + await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService, + importTable); + }); view.on(SourceModelSelectedEventName, (arg) => { view.modelSourceType = arg; view.refresh(); @@ -170,8 +187,14 @@ export class ModelManagementController extends ControllerBase { /** * Opens the dialog for model management */ - public async manageRegisteredModels(): Promise { - let view = new RegisteredModelsDialog(this._apiWrapper, this._root); + public async manageRegisteredModels(importTable?: DatabaseTable): Promise { + let view = new ManageModelsDialog(this._apiWrapper, this._root); + + if (importTable) { + view.importTable = importTable; + } else { + view.importTable = await this._registeredModelService.getRecentImportTable(); + } // Register events // @@ -202,8 +225,8 @@ export class ModelManagementController extends ControllerBase { return await service.getWorkspaces(account, subscription, group); } - private async getRegisteredModels(registeredModelService: DeployedModelService): Promise { - return registeredModelService.getDeployedModels(); + private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise { + return registeredModelService.getDeployedModels(table); } private async getAzureModels( @@ -221,9 +244,13 @@ export class ModelManagementController extends ControllerBase { private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise { if (models) { await Promise.all(models.map(async (model) => { - const localModel = model.modelData; - if (localModel) { - await service.deployLocalModel(localModel, model.modelDetails); + if (model && model.targetImportTable) { + const localModel = model.modelData; + if (localModel) { + await service.deployLocalModel(localModel, model.modelDetails, model.targetImportTable); + } + } else { + throw Error(constants.invalidModelToRegisterError); } })); } else { @@ -240,35 +267,39 @@ export class ModelManagementController extends ControllerBase { } await Promise.all(models.map(async (model) => { - const azureModel = model.modelData; - if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) { - let filePath: string | undefined; - try { - const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group, - azureModel.workspace, azureModel.model); - if (filePath) { - await service.deployLocalModel(filePath, model.modelDetails); - } else { - throw Error(constants.invalidModelToRegisterError); - } - } finally { - if (filePath) { - await fs.promises.unlink(filePath); + if (model && model.targetImportTable) { + const azureModel = model.modelData; + if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) { + let filePath: string | undefined; + try { + const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group, + azureModel.workspace, azureModel.model); + if (filePath) { + await service.deployLocalModel(filePath, model.modelDetails, model.targetImportTable); + } else { + throw Error(constants.invalidModelToRegisterError); + } + } finally { + if (filePath) { + await fs.promises.unlink(filePath); + } } } + } else { + throw Error(constants.invalidModelToRegisterError); } })); } - public async getDatabaseList(predictService: PredictService): Promise { + private async getDatabaseList(predictService: PredictService): Promise { return await predictService.getDatabaseList(); } - public async getTableList(predictService: PredictService, databaseName: string): Promise { + private async getTableList(predictService: PredictService, databaseName: string): Promise { return await predictService.getTableList(databaseName); } - public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise { + private async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise { return await predictService.getTableColumnsList(databaseTable); } @@ -285,6 +316,22 @@ export class ModelManagementController extends ControllerBase { return result; } + private async storeImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise { + if (table) { + await registeredModelService.storeRecentImportTable(table); + } else { + throw Error(constants.invalidImportTableError(undefined, undefined)); + } + } + + private async verifyImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise { + if (table) { + return await registeredModelService.verifyConfigTable(table); + } else { + throw Error(constants.invalidImportTableError(undefined, undefined)); + } + } + private async downloadRegisteredModel( registeredModelService: DeployedModelService, model: RegisteredModel | undefined): Promise { diff --git a/extensions/machine-learning-services/src/views/models/modelViewBase.ts b/extensions/machine-learning-services/src/views/models/modelViewBase.ts index f1c4cd4b2b..65ddc53270 100644 --- a/extensions/machine-learning-services/src/views/models/modelViewBase.ts +++ b/extensions/machine-learning-services/src/views/models/modelViewBase.ts @@ -37,6 +37,7 @@ export interface ModelViewData { modelFile?: string; modelData: AzureModelResource | string | RegisteredModel; modelDetails?: RegisteredModelDetails; + targetImportTable?: DatabaseTable; } // Event names @@ -58,6 +59,8 @@ export const PredictModelEventName = 'predictModel'; export const RegisterModelEventName = 'registerModel'; export const SourceModelSelectedEventName = 'sourceModelSelected'; export const LoadModelParametersEventName = 'loadModelParameters'; +export const StoreImportTableEventName = 'storeImportTable'; +export const VerifyImportTableEventName = 'verifyImportTable'; /** * Base class for all model management views @@ -66,6 +69,7 @@ export abstract class ModelViewBase extends ViewBase { private _modelSourceType: ModelSourceType = ModelSourceType.Local; private _modelsViewData: ModelViewData[] = []; + private _importTable: DatabaseTable | undefined; constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) { super(apiWrapper, root, parent); @@ -88,7 +92,9 @@ export abstract class ModelViewBase extends ViewBase { PredictModelEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, - LoadModelParametersEventName]); + LoadModelParametersEventName, + StoreImportTableEventName, + VerifyImportTableEventName]); } /** @@ -109,8 +115,8 @@ export abstract class ModelViewBase extends ViewBase { /** * list registered models */ - public async listModels(): Promise { - return await this.sendDataRequest(ListModelsEventName); + public async listModels(table: DatabaseTable): Promise { + return await this.sendDataRequest(ListModelsEventName, table); } /** @@ -156,7 +162,7 @@ export abstract class ModelViewBase extends ViewBase { * registers local model * @param localFilePath local file path */ - public async registerLocalModel(models: ModelViewData[]): Promise { + public async importLocalModel(models: ModelViewData[]): Promise { return await this.sendDataRequest(RegisterLocalModelEventName, models); } @@ -187,10 +193,24 @@ export abstract class ModelViewBase extends ViewBase { * registers azure model * @param args azure resource */ - public async registerAzureModel(models: ModelViewData[]): Promise { + public async importAzureModel(models: ModelViewData[]): Promise { return await this.sendDataRequest(RegisterAzureModelEventName, models); } + /** + * Stores the name of the table as recent config table for importing models + */ + public async storeImportConfigTable(): Promise { + await this.sendRequest(StoreImportTableEventName, this.importTable); + } + + /** + * Verifies if table is valid to import models to + */ + public async verifyImportConfigTable(table: DatabaseTable): Promise { + return await this.sendDataRequest(VerifyImportTableEventName, table); + } + /** * registers azure model * @param args azure resource @@ -240,7 +260,7 @@ export abstract class ModelViewBase extends ViewBase { } /** - * Sets model source type + * Sets model data */ public set modelsViewData(value: ModelViewData[]) { if (this.parent) { @@ -251,7 +271,7 @@ export abstract class ModelViewBase extends ViewBase { } /** - * Returns model source type + * Returns model data */ public get modelsViewData(): ModelViewData[] { if (this.parent) { @@ -261,6 +281,28 @@ export abstract class ModelViewBase extends ViewBase { } } + /** + * Sets import table + */ + public set importTable(value: DatabaseTable | undefined) { + if (this.parent) { + this.parent.importTable = value; + } else { + this._importTable = value; + } + } + + /** + * Returns import table + */ + public get importTable(): DatabaseTable | undefined { + if (this.parent) { + return this.parent.importTable; + } else { + return this._importTable; + } + } + /** * lists azure workspaces * @param account azure account diff --git a/extensions/machine-learning-services/src/views/models/prediction/inputColumnsComponent.ts b/extensions/machine-learning-services/src/views/models/prediction/inputColumnsComponent.ts index bdeca703a4..13d7d0db6a 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/inputColumnsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/inputColumnsComponent.ts @@ -11,6 +11,7 @@ import { IDataComponent } from '../../interfaces'; import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces'; import { ModelParameters } from '../../../modelManagement/interfaces'; import { ColumnsTable } from './columnsTable'; +import { TableSelectionComponent } from '../tableSelectionComponent'; /** * View to render filters to pick an azure resource @@ -18,14 +19,10 @@ import { ColumnsTable } from './columnsTable'; export class InputColumnsComponent extends ModelViewBase implements IDataComponent { private _form: azdata.FormContainer | undefined; - private _databases: azdata.DropDownComponent | undefined; - private _tables: azdata.DropDownComponent | undefined; + private _tableSelectionComponent: TableSelectionComponent | undefined; private _columns: ColumnsTable | undefined; - private _dbNames: string[] = []; - private _tableNames: DatabaseTable[] = []; private _modelParameters: ModelParameters | undefined; - private _dbTableComponent: azdata.FlexContainer | undefined; - private tableMaxLength = this.componentMaxLength * 2 + 70; + /** * Creates a new view */ @@ -38,53 +35,15 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone * @param modelBuilder model builder */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._databases = modelBuilder.dropDown().withProperties({ - width: this.componentMaxLength - }).component(); - this._tables = modelBuilder.dropDown().withProperties({ - width: this.componentMaxLength - }).component(); - this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this); - - this._databases.onValueChanged(async () => { - await this.onDatabaseSelected(); - }); - - this._tables.onValueChanged(async () => { + this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false); + this._tableSelectionComponent.registerComponent(modelBuilder); + this._tableSelectionComponent.onSelectedChanged(async () => { await this.onTableSelected(); }); - - const databaseForm = modelBuilder.formContainer().withFormItems([{ - title: constants.columnDatabase, - component: this._databases - }]).withLayout({ - padding: '0px' - }).component(); - const tableForm = modelBuilder.formContainer().withFormItems([{ - title: constants.columnTable, - component: this._tables - }]).withLayout({ - padding: '0px' - }).component(); - this._dbTableComponent = modelBuilder.flexContainer().withItems([ - databaseForm, - tableForm - ], { - flex: '0 0 auto', - CSSStyles: { - 'align-items': 'flex-start' - } - }).withLayout({ - flexFlow: 'row', - justifyContent: 'space-between', - width: this.tableMaxLength - }).component(); + this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this); this._form = modelBuilder.formContainer().withFormItems([{ - title: '', - component: this._dbTableComponent - }, { title: constants.inputColumns, component: this._columns.component }]).component(); @@ -92,10 +51,10 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone } public addComponents(formBuilder: azdata.FormBuilder) { - if (this._columns && this._dbTableComponent) { + if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) { formBuilder.addFormItems([{ title: '', - component: this._dbTableComponent + component: this._tableSelectionComponent.component }, { title: constants.inputColumns, component: this._columns.component @@ -104,10 +63,10 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone } public removeComponents(formBuilder: azdata.FormBuilder) { - if (this._columns && this._dbTableComponent) { + if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) { formBuilder.removeFormItem({ title: '', - component: this._dbTableComponent + component: this._tableSelectionComponent.component }); formBuilder.removeFormItem({ title: constants.inputColumns, @@ -136,12 +95,9 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone * loads data in the components */ public async loadData(): Promise { - this._dbNames = await this.listDatabaseNames(); - if (this._databases && this._dbNames && this._dbNames.length > 0) { - this._databases.values = this._dbNames; - this._databases.value = this._dbNames[0]; + if (this._tableSelectionComponent) { + this._tableSelectionComponent.refresh(); } - await this.onDatabaseSelected(); } public set modelParameters(value: ModelParameters) { @@ -167,31 +123,14 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone await this.loadData(); } - private async onDatabaseSelected(): Promise { - this._tableNames = await this.listTableNames(this.databaseName || ''); - if (this._tables && this._tableNames && this._tableNames.length > 0) { - this._tables.values = this._tableNames.map(t => this.getTableFullName(t)); - this._tables.value = this.getTableFullName(this._tableNames[0]); - } - await this.onTableSelected(); - } - - private getTableFullName(table: DatabaseTable): string { - return `${table.schema}.${table.tableName}`; - } - private async onTableSelected(): Promise { this._columns?.loadInputs(this._modelParameters, this.databaseTable); } - private get databaseName(): string | undefined { - return this._databases?.value; - } - private get databaseTable(): DatabaseTable { - let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._tables?.value); + let selectedItem = this._tableSelectionComponent?.data; return { - databaseName: this.databaseName, + databaseName: selectedItem?.databaseName, tableName: selectedItem?.tableName, schema: selectedItem?.schema }; diff --git a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts index 82eecb1127..572c046982 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts @@ -116,7 +116,7 @@ export class PredictWizard extends ModelViewBase { } else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) { return await this.azureModelsComponent.getDownloadedModel(); } else if (this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) { - return await this.modelBrowsePage.registeredModelsComponent.getDownloadedModel(); + return await this.modelBrowsePage.registeredModelsComponent.modelTable?.getDownloadedModel(); } return undefined; } diff --git a/extensions/machine-learning-services/src/views/models/tableSelectionComponent.ts b/extensions/machine-learning-services/src/views/models/tableSelectionComponent.ts new file mode 100644 index 0000000000..537c687f2e --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/tableSelectionComponent.ts @@ -0,0 +1,213 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as vscode from 'vscode'; +import { ModelViewBase } from './modelViewBase'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as constants from '../../common/constants'; +import { IDataComponent } from '../interfaces'; +import { DatabaseTable } from '../../prediction/interfaces'; + +/** + * View to render filters to pick an azure resource + */ +export class TableSelectionComponent extends ModelViewBase implements IDataComponent { + + private _form: azdata.FormContainer | undefined; + private _databases: azdata.DropDownComponent | undefined; + private _selectedTableName: string = ''; + private _tables: azdata.DropDownComponent | undefined; + private _dbNames: string[] = []; + private _tableNames: DatabaseTable[] = []; + private _dbTableComponent: azdata.FlexContainer | undefined; + private tableMaxLength = this.componentMaxLength * 2 + 70; + private _onSelectedChanged: vscode.EventEmitter = new vscode.EventEmitter(); + public readonly onSelectedChanged: vscode.Event = this._onSelectedChanged.event; + + /** + * Creates a new view + */ + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _editable: boolean) { + super(apiWrapper, parent.root, parent); + } + + /** + * Register components + * @param modelBuilder model builder + */ + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + this._databases = modelBuilder.dropDown().withProperties({ + width: this.componentMaxLength, + editable: this._editable, + fireOnTextChange: this._editable + }).component(); + this._tables = modelBuilder.dropDown().withProperties({ + width: this.componentMaxLength, + editable: this._editable, + fireOnTextChange: this._editable + }).component(); + + this._databases.onValueChanged(async () => { + await this.onDatabaseSelected(); + }); + + this._tables.onValueChanged(async (value) => { + // There's an issue with dropdown doesn't set the value in editable mode. this is the workaround + + if (this._tables && value) { + this._selectedTableName = this._editable ? value : value.selected; + } + await this.onTableSelected(); + }); + + const databaseForm = modelBuilder.formContainer().withFormItems([{ + title: constants.columnDatabase, + component: this._databases, + }]).withLayout({ + padding: '0px' + }).component(); + const tableForm = modelBuilder.formContainer().withFormItems([{ + title: constants.columnTable, + component: this._tables + }]).withLayout({ + padding: '0px' + }).component(); + this._dbTableComponent = modelBuilder.flexContainer().withItems([ + databaseForm, + tableForm + ], { + flex: '0 0 auto', + CSSStyles: { + 'align-items': 'flex-start' + } + }).withLayout({ + flexFlow: 'row', + justifyContent: 'space-between', + width: this.tableMaxLength + }).component(); + + this._form = modelBuilder.formContainer().withFormItems([{ + title: '', + component: this._dbTableComponent + }]).component(); + return this._form; + } + + public addComponents(formBuilder: azdata.FormBuilder) { + if (this._databases && this._tables) { + formBuilder.addFormItems([{ + title: constants.databaseName, + component: this._databases + }, { + title: constants.tableName, + component: this._tables + }]); + } + } + + public removeComponents(formBuilder: azdata.FormBuilder) { + if (this._databases && this._tables) { + formBuilder.removeFormItem({ + title: constants.databaseName, + component: this._databases + }); + formBuilder.removeFormItem({ + title: constants.tableName, + component: this._tables + }); + } + } + + /** + * Returns the created component + */ + public get component(): azdata.Component | undefined { + return this._dbTableComponent; + } + + /** + * Returns selected data + */ + public get data(): DatabaseTable | undefined { + return this.databaseTable; + } + + /** + * loads data in the components + */ + public async loadData(): Promise { + this._dbNames = await this.listDatabaseNames(); + if (this._databases && this._dbNames && this._dbNames.length > 0) { + this._databases.values = this._dbNames; + if (this.importTable) { + this._databases.value = this.importTable.databaseName; + } else { + this._databases.value = this._dbNames[0]; + } + } + await this.onDatabaseSelected(); + } + + /** + * refreshes the view + */ + public async refresh(): Promise { + await this.loadData(); + } + + private async onDatabaseSelected(): Promise { + this._tableNames = await this.listTableNames(this.databaseName || ''); + if (this._tables && this._tableNames && this._tableNames.length > 0) { + this._tables.values = this._tableNames.map(t => this.getTableFullName(t)); + if (this.importTable) { + const selectedTable = this._tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema); + if (selectedTable) { + this._selectedTableName = this.getTableFullName(selectedTable); + this._tables.value = this.getTableFullName(selectedTable); + } else { + this._selectedTableName = this.getTableFullName(this._tableNames[0]); + } + } else { + this._selectedTableName = this.getTableFullName(this._tableNames[0]); + } + this._tables.value = this._selectedTableName; + } else if (this._tables) { + this._tables.values = []; + this._tables.value = ''; + } + await this.onTableSelected(); + } + + private getTableFullName(table: DatabaseTable): string { + return `${table.schema}.${table.tableName}`; + } + + private async onTableSelected(): Promise { + this._onSelectedChanged.fire(); + } + + private get databaseName(): string | undefined { + return this._databases?.value; + } + + private get databaseTable(): DatabaseTable { + let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._selectedTableName); + if (!selectedItem) { + const value = this._selectedTableName; + const parts = value ? value.split('.') : undefined; + selectedItem = { + databaseName: this.databaseName, + tableName: parts && parts.length > 1 ? parts[1] : value, + schema: parts && parts.length > 1 ? parts[0] : 'dbo', + }; + } + return { + databaseName: this.databaseName, + tableName: selectedItem?.tableName, + schema: selectedItem?.schema + }; + } +} diff --git a/extensions/machine-learning-services/src/views/widgets/dashboardWidget.ts b/extensions/machine-learning-services/src/views/widgets/dashboardWidget.ts index 4570f8bf96..1d25d6d186 100644 --- a/extensions/machine-learning-services/src/views/widgets/dashboardWidget.ts +++ b/extensions/machine-learning-services/src/views/widgets/dashboardWidget.ts @@ -369,7 +369,7 @@ export class DashboardWidget { light: this.asAbsolutePath('images/makePredictions.svg'), }, link: '', - command: constants.mlImportModelCommand + command: constants.mlManageModelsCommand }; const importModelsButton = this.createTaskButton(view, importMetadata); const notebookMetadata: IActionMetadata = { diff --git a/extensions/machine-learning-services/src/views/wizardView.ts b/extensions/machine-learning-services/src/views/wizardView.ts index 49b1adef5a..3067191cdc 100644 --- a/extensions/machine-learning-services/src/views/wizardView.ts +++ b/extensions/machine-learning-services/src/views/wizardView.ts @@ -75,7 +75,7 @@ export class WizardView extends MainViewBase { } public async validate(pageInfo: azdata.window.WizardPageChangeInfo): Promise { - if (pageInfo.lastPage !== undefined) { + if (pageInfo?.lastPage !== undefined) { let idxLast = pageInfo.lastPage; let lastPage = this._pages[idxLast]; if (lastPage && lastPage.validate) { @@ -86,16 +86,23 @@ export class WizardView extends MainViewBase { } private async onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) { - let idxLast = pageInfo.lastPage; - let lastPage = this._pages[idxLast]; - if (lastPage && lastPage.onLeave) { - await lastPage.onLeave(); + if (pageInfo?.lastPage !== undefined) { + let idxLast = pageInfo.lastPage; + let lastPage = this._pages[idxLast]; + if (lastPage && lastPage.onLeave) { + await lastPage.onLeave(); + } } - let idx = pageInfo.newPage; - let page = this._pages[idx]; - if (page && page.onEnter) { - await page.onEnter(); + if (pageInfo?.newPage !== undefined) { + let idx = pageInfo.newPage; + let page = this._pages[idx]; + if (page && page.onEnter) { + if (this._wizard && this._wizard.pages.length > idx) { + this._wizard.pages[idx].title = page.title; + } + await page.onEnter(); + } } }