diff --git a/extensions/machine-learning-services/images/arrow.svg b/extensions/machine-learning-services/images/arrow.svg new file mode 100644 index 0000000000..91626b1129 --- /dev/null +++ b/extensions/machine-learning-services/images/arrow.svg @@ -0,0 +1,3 @@ + + + diff --git a/extensions/machine-learning-services/src/common/constants.ts b/extensions/machine-learning-services/src/common/constants.ts index e8d55ba9be..6711463a77 100644 --- a/extensions/machine-learning-services/src/common/constants.ts +++ b/extensions/machine-learning-services/src/common/constants.ts @@ -60,7 +60,7 @@ export function confirmInstallPythonPackages(packages: string): string { export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ..."); export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed."); export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); } -export const packageManagerNoConnection = localize('mls.packageManager.NoConnection', "No connection selected"); +export const noConnectionError = localize('mls.packageManager.NoConnection', "No connection selected"); export const notebookExtensionNotLoaded = localize('mls.notebookExtensionNotLoaded', "Notebook extension is not loaded"); export const mssqlExtensionNotLoaded = localize('mls.mssqlExtensionNotLoaded', "MSSQL extension is not loaded"); export const mlsEnabledMessage = localize('mls.enabledMessage', "Machine Learning Services Enabled"); @@ -74,6 +74,8 @@ export const mlsExternalExecuteScriptTitle = localize('mls.externalExecuteScript export const mlsPythonLanguageTitle = localize('mls.pythonLanguageTitle', "Python"); export const mlsRLanguageTitle = localize('mls.rLanguageTitle', "R"); export const downloadError = localize('mls.downloadError', "Error while downloading"); +export function invalidModelIdError(modelUrl: string | undefined): string { return localize('mls.invalidModelIdError', "Invalid model id. model url: {0}", modelUrl || ''); } +export function noArtifactError(modelUrl: string | undefined): string { return localize('mls.noArtifactError', "Model doesn't have any artifact. model url: {0}", modelUrl || ''); } export const downloadingProgress = localize('mls.downloadingProgress', "Downloading"); export const pythonConfigError = localize('mls.pythonConfigError', "Python executable is not configured"); export const rConfigError = localize('mls.rConfigError', "R executable is not configured"); @@ -119,12 +121,15 @@ export const modelCreated = localize('models.created', "Date Created"); export const modelVersion = localize('models.version', "Version"); export const browseModels = localize('models.browseButton', "..."); export const azureAccount = localize('models.azureAccount', "Azure account"); -export const columnDatabase = localize('predict.columnDatabase', "Database"); -export const columnTable = localize('predict.columnTable', "Table"); -export const inputColumns = localize('predict.inputColumns', "Input columns"); -export const outputColumns = localize('predict.outputColumns', "Output column"); -export const columnName = localize('predict.columnName', "Name"); -export const inputName = localize('predict.inputName', "Input Name"); +export const columnDatabase = localize('predict.columnDatabase', "Target database"); +export const columnTable = localize('predict.columnTable', "Target table"); +export const inputColumns = localize('predict.inputColumns', "Model input mapping"); +export const outputColumns = localize('predict.outputColumns', "Model output"); +export const columnName = localize('predict.columnName', "Target columns"); +export const dataTypeName = localize('predict.dataTypeName', "Type"); +export const displayName = localize('predict.displayName', "Display name"); +export const inputName = localize('predict.inputName', "Required model input features"); +export const outputName = localize('predict.outputName', "Name"); export const azureSubscription = localize('models.azureSubscription', "Azure subscription"); export const azureGroup = localize('models.azureGroup', "Azure resource group"); export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace"); @@ -134,7 +139,7 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model export const localModelsTitle = localize('models.localModelsTitle', "Local models"); export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location"); export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Enter model source details"); -export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Select input columns"); +export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input"); export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details"); export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file"); export const currentModelsTitle = localize('models.currentModelsTitle', "Models"); @@ -156,6 +161,8 @@ export const invalidModelToSelectError = localize('models.invalidModelToSelectEr export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model"); +export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'"); +export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported"); diff --git a/extensions/machine-learning-services/src/common/processService.ts b/extensions/machine-learning-services/src/common/processService.ts index f6a95031f4..9dc050cf4c 100644 --- a/extensions/machine-learning-services/src/common/processService.ts +++ b/extensions/machine-learning-services/src/common/processService.ts @@ -23,16 +23,19 @@ export class ProcessService { scriptExecution.stdin.end(); // Add listeners to print stdout and stderr if an output channel was provided - if (outputChannel) { - scriptExecution.stdout.on('data', data => { + + scriptExecution.stdout.on('data', data => { + if (outputChannel) { this.outputDataChunk(data, outputChannel, ' stdout: '); - output = output + data.toString(); - }); - scriptExecution.stderr.on('data', data => { + } + output = output + data.toString(); + }); + scriptExecution.stderr.on('data', data => { + if (outputChannel) { this.outputDataChunk(data, outputChannel, ' stderr: '); - output = output + data.toString(); - }); - } + } + output = output + data.toString(); + }); scriptExecution.on('exit', (code) => { if (timer) { diff --git a/extensions/machine-learning-services/src/common/utils.ts b/extensions/machine-learning-services/src/common/utils.ts index 483eb9440c..96c8697c38 100644 --- a/extensions/machine-learning-services/src/common/utils.ts +++ b/extensions/machine-learning-services/src/common/utils.ts @@ -22,7 +22,17 @@ export async function execCommandOnTempFile(content: string, command: (filePa return result; } finally { - await fs.promises.unlink(tempFilePath); + await deleteFile(tempFilePath); + } +} + +/** + * Deletes a file + * @param filePath file path + */ +export async function deleteFile(filePath: string) { + if (filePath) { + await fs.promises.unlink(filePath); } } @@ -215,7 +225,7 @@ export function getRegisteredModelsThreePartsName(config: Config) { const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName); const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName); const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName); - return `[${dbName}].${schema}.[${tableName}]`; + return `[${dbName}].[${schema}].[${tableName}]`; } /** @@ -227,3 +237,14 @@ export function getRegisteredModelsTowPartsName(config: Config) { const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName); return `[${schema}].[${tableName}]`; } + +/** + * Write a file using a hex string + * @param content file content + */ +export async function writeFileFromHex(content: string): Promise { + content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content; + const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`); + await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex')); + return tempFilePath; +} diff --git a/extensions/machine-learning-services/src/controllers/mainController.ts b/extensions/machine-learning-services/src/controllers/mainController.ts index 390dfb3d1e..a9cc6babf6 100644 --- a/extensions/machine-learning-services/src/controllers/mainController.ts +++ b/extensions/machine-learning-services/src/controllers/mainController.ts @@ -18,9 +18,9 @@ import { HttpClient } from '../common/httpClient'; import { LanguageController } from '../views/externalLanguages/languageController'; import { LanguageService } from '../externalLanguage/languageService'; import { ModelManagementController } from '../views/models/modelManagementController'; -import { RegisteredModelService } from '../modelManagement/registeredModelService'; +import { DeployedModelService } from '../modelManagement/deployedModelService'; import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService'; -import { ModelImporter } from '../modelManagement/modelImporter'; +import { ModelPythonClient } from '../modelManagement/modelPythonClient'; import { PredictService } from '../prediction/predictService'; /** @@ -100,11 +100,11 @@ export default class MainController implements vscode.Disposable { let mssqlService = await this.getLanguageExtensionService(); let languagesModel = new LanguageService(this._apiWrapper, mssqlService); let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel); - let modelImporter = new ModelImporter(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager); + let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager); // Model Management // - let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter); + let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter); let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel); let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config); let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, diff --git a/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts b/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts index b63bead25e..47f24c0532 100644 --- a/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts +++ b/extensions/machine-learning-services/src/modelManagement/azureModelRegistryService.ts @@ -28,10 +28,16 @@ import * as utils from '../common/utils'; */ export class AzureModelRegistryService { + private _amlClient: AzureMachineLearningWorkspaces | undefined; + private _modelClient: WorkspaceModels | undefined; /** - * + * Creates new service */ - constructor(private _apiWrapper: ApiWrapper, private _config: Config, private _httpClient: HttpClient, private _outputChannel: vscode.OutputChannel) { + constructor( + private _apiWrapper: ApiWrapper, + private _config: Config, + private _httpClient: HttpClient, + private _outputChannel: vscode.OutputChannel) { } /** @@ -120,10 +126,18 @@ export class AzureModelRegistryService { return downloadedFilePath; } + public set AzureMachineLearningClient(value: AzureMachineLearningWorkspaces) { + this._amlClient = value; + } + + public set ModelClient(value: WorkspaceModels) { + this._modelClient = value; + } + /** - * Installs dependencies for the extension + * Execute the background task to download the artifact */ - public async execDownloadArtifactTask(downloadUrl: string): Promise { + private async execDownloadArtifactTask(downloadUrl: string): Promise { let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true); return results && results.length > 0 ? results[0] : constants.noResultError; } @@ -139,15 +153,14 @@ export class AzureModelRegistryService { try { for (const tenant of account.properties.tenants) { - const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement); - const token = tokens[tenant.id].token; - const tokenType = tokens[tenant.id].tokenType; - const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id); + const client = await this.getAmlClient(account, subscription, tenant); let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription(); - resources.push(...result); + if (result) { + resources.push(...result); + } } } catch (error) { - + console.log(error); } return resources; } @@ -161,9 +174,11 @@ export class AzureModelRegistryService { for (const tenant of account.properties.tenants) { try { - let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl); - const client = await this.getClient(baseUri, account, subscription, tenant); - let modelsClient = new WorkspaceModels(client); + let options: AzureMachineLearningWorkspacesOptions = { + baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl) + }; + const client = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion); + let modelsClient = this.getModelClient(client); resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || '')); } catch (error) { @@ -182,22 +197,28 @@ export class AzureModelRegistryService { client: AzureMachineLearningWorkspaces): Promise { const modelId = this.getModelId(model); - let modelsClient = new Assets(client); - return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId); + if (modelId) { + let modelsClient = new Assets(client); + return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId); + } else { + throw Error(constants.invalidModelIdError(model.url)); + } } - public async getAssetArtifactsDownloadLinks( + private async getAssetArtifactsDownloadLinks( account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource, workspace: Workspace, model: WorkspaceModel, tenant: any): Promise { - let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl); - const modelManagementClient = await this.getClient(baseUri, account, subscription, tenant); + let options: AzureMachineLearningWorkspacesOptions = { + baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl) + }; + const modelManagementClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion); const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient); - baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl); - const experienceClient = await this.getClient(baseUri, account, subscription, tenant); + options.baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl); + const experienceClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion); const artifactClient = new Artifacts(experienceClient); let downloadLinks: string[] = []; if (asset && asset.artifacts) { @@ -230,17 +251,19 @@ export class AzureModelRegistryService { downloadLinkPromises.push(promise); } } - try { downloadLinks = await Promise.all(downloadLinkPromises); } catch (rejectedPromiseError) { return rejectedPromiseError; } + return downloadLinks; + + } else { + throw Error(constants.noArtifactError(model.url)); } - return downloadLinks; } - public getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined { + private getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined { const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/; if (idOrPrefix) { const parts = artifactRegex.exec(idOrPrefix); @@ -263,16 +286,35 @@ export class AzureModelRegistryService { return baseUri; } - private async getClient(baseUri: string, account: azdata.Account, subscription: azureResource.AzureResourceSubscription, tenant: any): Promise { - const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement); - const token = tokens[tenant.id].token; - const tokenType = tokens[tenant.id].tokenType; - const options: AzureMachineLearningWorkspacesOptions = { - baseUri: baseUri - }; - const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options); - client.apiVersion = this._config.amlApiVersion; - return client; + private getModelClient(amlClient: AzureMachineLearningWorkspaces) { + return this._modelClient ?? new WorkspaceModels(amlClient); + } + + private async getAmlClient( + account: azdata.Account, + subscription: azureResource.AzureResourceSubscription, + tenant: any, + options: AzureMachineLearningWorkspacesOptions | undefined = undefined, + apiVersion: string | undefined = undefined): Promise { + if (this._amlClient) { + return this._amlClient; + } else { + const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement); + let token: string = ''; + let tokenType: string | undefined = undefined; + if (tokens && tenant.id in tokens) { + const tokenForId = tokens[tenant.id]; + if (tokenForId) { + token = tokenForId.token; + tokenType = tokenForId.tokenType; + } + } + const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options); + if (apiVersion) { + client.apiVersion = apiVersion; + } + return client; + } } private getModelId(model: WorkspaceModel): string { diff --git a/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts similarity index 73% rename from extensions/machine-learning-services/src/modelManagement/registeredModelService.ts rename to extensions/machine-learning-services/src/modelManagement/deployedModelService.ts index 6720bc2563..7bee455235 100644 --- a/extensions/machine-learning-services/src/modelManagement/registeredModelService.ts +++ b/extensions/machine-learning-services/src/modelManagement/deployedModelService.ts @@ -9,73 +9,85 @@ import { ApiWrapper } from '../common/apiWrapper'; import * as utils from '../common/utils'; import { Config } from '../configurations/config'; import { QueryRunner } from '../common/queryRunner'; -import { RegisteredModel, RegisteredModelDetails } from './interfaces'; -import { ModelImporter } from './modelImporter'; +import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces'; +import { ModelPythonClient } from './modelPythonClient'; import * as constants from '../common/constants'; /** - * Service to registered models + * Service to deployed models */ -export class RegisteredModelService { +export class DeployedModelService { /** - * + * Creates new instance */ constructor( private _apiWrapper: ApiWrapper, private _config: Config, private _queryRunner: QueryRunner, - private _modelImporter: ModelImporter) { + private _modelClient: ModelPythonClient) { } - public async getRegisteredModels(): Promise { + /** + * Returns deployed models + */ + public async getDeployedModels(): Promise { let connection = await this.getCurrentConnection(); let list: RegisteredModel[] = []; if (connection) { let query = this.getConfigureQuery(connection.databaseName); await this._queryRunner.safeRunQuery(connection, query); - query = this.registeredModelsQuery(); + query = this.getDeployedModelsQuery(); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { result.rows.forEach(row => { list.push(this.loadModelData(row)); }); } + } else { + throw Error(constants.noConnectionError); } return list; } - private loadModelData(row: azdata.DbCellValue[]): RegisteredModel { - return { - id: +row[0].displayValue, - artifactName: row[1].displayValue, - title: row[2].displayValue, - description: row[3].displayValue, - version: row[4].displayValue, - created: row[5].displayValue - }; - } - - public async updateModel(model: RegisteredModel): Promise { + /** + * Downloads model + * @param model model object + */ + public async downloadModel(model: RegisteredModel): Promise { let connection = await this.getCurrentConnection(); - let updatedModel: RegisteredModel | undefined = undefined; if (connection) { - const query = this.getUpdateModelScript(connection.databaseName, model); + const query = this.getModelContentQuery(model); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { - const row = result.rows[0]; - updatedModel = this.loadModelData(row); + const content = result.rows[0][0].displayValue; + return await utils.writeFileFromHex(content); + } else { + throw Error(constants.invalidModelToSelectError); } + } else { + throw Error(constants.noConnectionError); } - return updatedModel; } - public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) { + /** + * Loads model parameters + */ + public async loadModelParameters(filePath: string): Promise { + return await this._modelClient.loadModelParameters(filePath); + } + + /** + * Deploys local model + * @param filePath model file path + * @param details model details + */ + public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined) { let connection = await this.getCurrentConnection(); if (connection) { - let currentModels = await this.getRegisteredModels(); - await this._modelImporter.registerModel(connection, filePath); - let updatedModels = await this.getRegisteredModels(); + let currentModels = await this.getDeployedModels(); + await this._modelClient.deployModel(connection, filePath); + let updatedModels = await this.getDeployedModels(); if (details && updatedModels.length >= currentModels.length + 1) { updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0); const addedModel = updatedModels[updatedModels.length - 1]; @@ -92,16 +104,40 @@ export class RegisteredModelService { } } } + private loadModelData(row: azdata.DbCellValue[]): RegisteredModel { + return { + id: +row[0].displayValue, + artifactName: row[1].displayValue, + title: row[2].displayValue, + description: row[3].displayValue, + version: row[4].displayValue, + created: row[5].displayValue + }; + } + + private async updateModel(model: RegisteredModel): Promise { + let connection = await this.getCurrentConnection(); + let updatedModel: RegisteredModel | undefined = undefined; + if (connection) { + const query = this.getUpdateModelQuery(connection.databaseName, model); + let result = await this._queryRunner.safeRunQuery(connection, query); + if (result?.rows && result.rows.length > 0) { + const row = result.rows[0]; + updatedModel = this.loadModelData(row); + } + } + return updatedModel; + } private async getCurrentConnection(): Promise { return await this._apiWrapper.getCurrentConnection(); } - private getConfigureQuery(currentDatabaseName: string): string { - return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable()); + public getConfigureQuery(currentDatabaseName: string): string { + return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.getConfigureTableQuery()); } - private registeredModelsQuery(): string { + public getDeployedModelsQuery(): string { return ` SELECT artifact_id, artifact_name, name, description, version, created FROM ${utils.getRegisteredModelsThreePartsName(this._config)} @@ -116,7 +152,7 @@ export class RegisteredModelService { * @param databaseName * @param tableName */ - private configureTable(): string { + public getConfigureTableQuery(): string { let databaseName = this._config.registeredModelDatabaseName; let tableName = this._config.registeredModelTableName; let schemaName = this._config.registeredModelTableSchemaName; @@ -171,7 +207,7 @@ export class RegisteredModelService { `; } - private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string { + public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string { let updateScript = ` UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)} SET @@ -187,4 +223,12 @@ export class RegisteredModelService { WHERE artifact_id = ${model.id}; `; } + + public getModelContentQuery(model: RegisteredModel): string { + return ` + SELECT artifact_content + FROM ${utils.getRegisteredModelsThreePartsName(this._config)} + WHERE artifact_id = ${model.id}; + `; + } } diff --git a/extensions/machine-learning-services/src/modelManagement/interfaces.ts b/extensions/machine-learning-services/src/modelManagement/interfaces.ts index 212c3adc34..f827bffc34 100644 --- a/extensions/machine-learning-services/src/modelManagement/interfaces.ts +++ b/extensions/machine-learning-services/src/modelManagement/interfaces.ts @@ -53,6 +53,16 @@ export interface RegisteredModel extends RegisteredModelDetails { artifactName: string; } +export interface ModelParameter { + name: string; + type: string; +} + +export interface ModelParameters { + inputs: ModelParameter[], + outputs: ModelParameter[] +} + /** * An interface representing registered model */ diff --git a/extensions/machine-learning-services/src/modelManagement/modelImporter.ts b/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts similarity index 53% rename from extensions/machine-learning-services/src/modelManagement/modelImporter.ts rename to extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts index ad00576055..1b5022b554 100644 --- a/extensions/machine-learning-services/src/modelManagement/modelImporter.ts +++ b/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts @@ -13,33 +13,89 @@ import * as utils from '../common/utils'; import { PackageManager } from '../packageManagement/packageManager'; import * as constants from '../common/constants'; import * as os from 'os'; +import { ModelParameters } from './interfaces'; /** - * Service to import model to database + * Python client for ONNX models */ -export class ModelImporter { +export class ModelPythonClient { /** - * + * Creates new instance */ constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) { } - public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise { + /** + * Deploys models in the SQL database using mlflow + * @param connection + * @param modelPath + */ + public async deployModel(connection: azdata.connection.ConnectionProfile, modelPath: string): Promise { await this.installDependencies(); - await this.executeScripts(connection, modelFolderPath); + await this.executeDeployScripts(connection, modelPath); } /** - * Installs dependencies for model importer + * Installs dependencies for python client */ - public async installDependencies(): Promise { + private async installDependencies(): Promise { await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [ this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true); } - protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise { + /** + * + * @param modelPath Loads model parameters + */ + public async loadModelParameters(modelPath: string): Promise { + await this.installDependencies(); + return await this.executeModelParametersScripts(modelPath); + } + private async executeModelParametersScripts(modelFolderPath: string): Promise { + modelFolderPath = utils.makeLinuxPath(modelFolderPath); + + let scripts: string[] = [ + 'import onnx', + 'import json', + `onnx_model_path = '${modelFolderPath}'`, + `onnx_model = onnx.load_model(onnx_model_path)`, + `type_map = { + onnx.TensorProto.DataType.FLOAT: 'real', + onnx.TensorProto.DataType.UINT8: 'tinyint', + onnx.TensorProto.DataType.INT16: 'smallint', + onnx.TensorProto.DataType.INT32: 'int', + onnx.TensorProto.DataType.INT64: 'bigint', + onnx.TensorProto.DataType.STRING: 'varchar(MAX)', + onnx.TensorProto.DataType.DOUBLE: 'float'}`, + `parameters = { + "inputs": [], + "outputs": [] + }`, + `def addParameters(list, paramType): + for id, p in enumerate(list): + p_type = '' + + if p.type.tensor_type.elem_type in type_map: + p_type = type_map[p.type.tensor_type.elem_type] + + parameters[paramType].append({ + 'name': p.name, + 'type': p_type + })`, + + 'addParameters(onnx_model.graph.input, "inputs")', + 'addParameters(onnx_model.graph.output, "outputs")', + 'print(json.dumps(parameters))' + ]; + let pythonExecutable = this._config.pythonExecutable; + let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined); + let parametersJson = JSON.parse(output); + return Object.assign({}, parametersJson); + } + + private async executeDeployScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise { let home = utils.makeLinuxPath(os.homedir()); modelFolderPath = utils.makeLinuxPath(modelFolderPath); diff --git a/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts b/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts index 703c2d5d4c..90870aa84a 100644 --- a/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts +++ b/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts @@ -30,7 +30,7 @@ export abstract class SqlPackageManageProviderBase { if (connection) { return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`; } - return constants.packageManagerNoConnection; + return constants.noConnectionError; } protected async getCurrentConnection(): Promise { diff --git a/extensions/machine-learning-services/src/prediction/interfaces.ts b/extensions/machine-learning-services/src/prediction/interfaces.ts index 5fcb789edd..2274a0d8b1 100644 --- a/extensions/machine-learning-services/src/prediction/interfaces.ts +++ b/extensions/machine-learning-services/src/prediction/interfaces.ts @@ -3,10 +3,13 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -export interface PredictColumn { - name: string; +export interface TableColumn { + columnName: string; dataType?: string; - displayName?: string; +} + +export interface PredictColumn extends TableColumn { + paramName?: string; } export interface DatabaseTable { diff --git a/extensions/machine-learning-services/src/prediction/predictService.ts b/extensions/machine-learning-services/src/prediction/predictService.ts index 31b714c82d..eebb89bd29 100644 --- a/extensions/machine-learning-services/src/prediction/predictService.ts +++ b/extensions/machine-learning-services/src/prediction/predictService.ts @@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper'; import { QueryRunner } from '../common/queryRunner'; import * as utils from '../common/utils'; import { RegisteredModel } from '../modelManagement/interfaces'; -import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces'; +import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces'; import { Config } from '../configurations/config'; /** @@ -98,15 +98,18 @@ export class PredictService { *Returns list of column names of a database * @param databaseTable table info */ - public async getTableColumnsList(databaseTable: DatabaseTable): Promise { + public async getTableColumnsList(databaseTable: DatabaseTable): Promise { let connection = await this.getCurrentConnection(); - let list: string[] = []; + let list: TableColumn[] = []; if (connection && databaseTable.databaseName) { const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable)); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { result.rows.forEach(row => { - list.push(row[0].displayValue); + list.push({ + columnName: row[0].displayValue, + dataType: row[1].displayValue + }); }); } } @@ -119,7 +122,7 @@ export class PredictService { private getTableColumnsScript(databaseTable: DatabaseTable): string { return ` -SELECT COLUMN_NAME,* +SELECT COLUMN_NAME,DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}' AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}' @@ -149,14 +152,14 @@ DECLARE @model VARBINARY(max) = ( WITH predict_input AS ( SELECT TOP 1000 - ${this.getColumnNames(columns, 'pi')} + ${this.getInputColumnNames(columns, 'pi')} FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi ) SELECT -${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')} +${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')} FROM PREDICT(MODEL = @model, DATA = predict_input) WITH ( - ${this.getColumnTypes(outputColumns)} + ${this.getOutputParameters(outputColumns)} ) AS p `; } @@ -170,33 +173,43 @@ WITH ( WITH predict_input AS ( SELECT TOP 1000 - ${this.getColumnNames(columns, 'pi')} + ${this.getInputColumnNames(columns, 'pi')} FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi ) SELECT -${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')} +${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')} FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input) WITH ( - ${this.getColumnTypes(outputColumns)} + ${this.getOutputParameters(outputColumns)} ) AS p `; } - private getColumnNames(columns: PredictColumn[], tableName: string) { - return columns.map(c => { - return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`; - }).join(',\n'); - } - private getInputColumnNames(columns: PredictColumn[], tableName: string) { return columns.map(c => { - return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`; + return this.getColumnName(tableName, c.paramName || '', c.columnName); }).join(',\n'); } - private getColumnTypes(columns: PredictColumn[]) { + private getOutputColumnNames(columns: PredictColumn[], tableName: string) { return columns.map(c => { - return `${c.name} ${c.dataType}`; + return this.getColumnName(tableName, c.columnName, c.paramName || ''); + }).join(',\n'); + } + + private getColumnName(tableName: string, columnName: string, displayName: string) { + return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`; + } + + private getPredictColumnNames(columns: PredictColumn[], tableName: string) { + return columns.map(c => { + return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`; + }).join(',\n'); + } + + private getOutputParameters(columns: PredictColumn[]) { + return columns.map(c => { + return `${c.paramName} ${c.dataType}`; }).join(',\n'); } } diff --git a/extensions/machine-learning-services/src/test/modelManagement/azureModelRegistryService.test.ts b/extensions/machine-learning-services/src/test/modelManagement/azureModelRegistryService.test.ts new file mode 100644 index 0000000000..e2a92f9169 --- /dev/null +++ b/extensions/machine-learning-services/src/test/modelManagement/azureModelRegistryService.test.ts @@ -0,0 +1,232 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as vscode from 'vscode'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as TypeMoq from 'typemoq'; +import * as should from 'should'; +import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService'; +import { Config } from '../../configurations/config'; +import { HttpClient } from '../../common/httpClient'; +import { azureResource } from '../../typings/azure-resource'; + +import * as utils from '../utils'; +import { Workspace, WorkspacesListByResourceGroupResponse } from '@azure/arm-machinelearningservices/esm/models'; +import { WorkspaceModel, AssetsQueryByIdResponse, Asset, GetArtifactContentInformation2Response } from '../../modelManagement/interfaces'; +import { AzureMachineLearningWorkspaces, Workspaces } from '@azure/arm-machinelearningservices'; +import { WorkspaceModels } from '../../modelManagement/workspacesModels'; + +interface TestContext { + + apiWrapper: TypeMoq.IMock; + config: TypeMoq.IMock; + httpClient: TypeMoq.IMock; + outputChannel: vscode.OutputChannel; + op: azdata.BackgroundOperation; + accounts: azdata.Account[]; + subscriptions: azureResource.AzureResourceSubscription[]; + groups: azureResource.AzureResourceResourceGroup[]; + workspaces: Workspace[]; + models: WorkspaceModel[]; + client: TypeMoq.IMock; + workspacesClient: TypeMoq.IMock; + modelClient: TypeMoq.IMock; +} + +function createContext(): TestContext { + const context = utils.createContext(); + const workspaces = TypeMoq.Mock.ofType(Workspaces); + const credentials = { + signRequest: () => { + return Promise.resolve(undefined!!); + } + }; + const client = TypeMoq.Mock.ofInstance(new AzureMachineLearningWorkspaces(credentials, 'subscription')); + client.setup(x => x.apiVersion).returns(() => '20180101'); + + return { + apiWrapper: TypeMoq.Mock.ofType(ApiWrapper), + config: TypeMoq.Mock.ofType(Config), + httpClient: TypeMoq.Mock.ofType(HttpClient), + outputChannel: context.outputChannel, + op: context.op, + accounts: [ + { + key: { + providerId: '', + accountId: 'a1' + }, + displayInfo: { + contextualDisplayName: '', + accountType: '', + displayName: 'a1', + userId: 'a1' + }, + properties: + { + tenants: [ + { + id: '1', + } + ] + } + , + isStale: true + } + ], + subscriptions: [ + { + name: 's1', + id: 's1' + } + ], + groups: [ + { + name: 'g1', + id: 'g1' + } + ], + workspaces: [{ + name: 'w1', + id: 'w1' + } + ], + models: [ + { + name: 'm1', + id: 'm1', + url: 'aml://asset/test.test' + } + ], + client: client, + workspacesClient: workspaces, + modelClient: TypeMoq.Mock.ofInstance(new WorkspaceModels(client.object)) + }; +} + +describe('AzureModelRegistryService', () => { + it('getAccounts should return the list of accounts successfully', async function (): Promise { + let testContext = createContext(); + const accounts = testContext.accounts; + let service = new AzureModelRegistryService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.httpClient.object, + testContext.outputChannel); + testContext.apiWrapper.setup(x => x.getAllAccounts()).returns(() => Promise.resolve(accounts)); + let actual = await service.getAccounts(); + should.deepEqual(actual, testContext.accounts); + }); + + it('getSubscriptions should return the list of subscriptions successfully', async function (): Promise { + let testContext = createContext(); + const expected = testContext.subscriptions; + let service = new AzureModelRegistryService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.httpClient.object, + testContext.outputChannel); + testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ subscriptions: expected, errors: [] })); + let actual = await service.getSubscriptions(testContext.accounts[0]); + should.deepEqual(actual, expected); + }); + + it('getGroups should return the list of groups successfully', async function (): Promise { + let testContext = createContext(); + const expected = testContext.groups; + let service = new AzureModelRegistryService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.httpClient.object, + testContext.outputChannel); + testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ resourceGroups: expected, errors: [] })); + let actual = await service.getGroups(testContext.accounts[0], testContext.subscriptions[0]); + should.deepEqual(actual, expected); + }); + + it('getWorkspaces should return the list of workspaces successfully', async function (): Promise { + let testContext = createContext(); + const response: WorkspacesListByResourceGroupResponse = Object.assign(new Array(...testContext.workspaces), { + _response: undefined! + }); + const expected = testContext.workspaces; + testContext.workspacesClient.setup(x => x.listByResourceGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(response)); + testContext.workspacesClient.setup(x => x.listBySubscription()).returns(() => Promise.resolve(response)); + testContext.client.setup(x => x.workspaces).returns(() => testContext.workspacesClient.object); + let service = new AzureModelRegistryService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.httpClient.object, + testContext.outputChannel); + + + service.AzureMachineLearningClient = testContext.client.object; + let actual = await service.getWorkspaces(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0]); + should.deepEqual(actual, expected); + }); + + it('getModels should return the list of models successfully', async function (): Promise { + let testContext = createContext(); + testContext.config.setup(x => x.amlApiVersion).returns(() => '2018'); + testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url'); + const expected = testContext.models; + let service = new AzureModelRegistryService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.httpClient.object, + testContext.outputChannel); + service.AzureMachineLearningClient = testContext.client.object; + service.ModelClient = testContext.modelClient.object; + testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models)); + let actual = await service.getModels(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0]); + should.deepEqual(actual, expected); + }); + + it('downloadModel should download model artifact successfully', async function (): Promise { + let testContext = createContext(); + const asset: Asset = + { + id: '1', + name: 'asset', + artifacts: [ + { + id: '/1/2/3/4/5/' + } + ] + }; + const assetResponse: AssetsQueryByIdResponse = Object.assign(asset, { + _response: undefined! + }); + const artifactResponse: GetArtifactContentInformation2Response = Object.assign({ + contentUri: 'downloadUrl' + }, { + _response: undefined! + }); + + testContext.config.setup(x => x.amlApiVersion).returns(() => '2018'); + testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url'); + testContext.config.setup(x => x.amlExperienceUrl).returns(() => 'test.url'); + testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(), + TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('modelmanagement')), TypeMoq.It.isAny())).returns(() => Promise.resolve(assetResponse)); + testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(), + TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('artifact')), TypeMoq.It.isAny())).returns(() => Promise.resolve(artifactResponse)); + testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { + operationInfo.operation(testContext.op); + }); + testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve()); + let service = new AzureModelRegistryService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.httpClient.object, + testContext.outputChannel); + service.AzureMachineLearningClient = testContext.client.object; + service.ModelClient = testContext.modelClient.object; + testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models)); + let actual = await service.downloadModel(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0], testContext.models[0]); + should.notEqual(actual, undefined); + testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); + }); +}); diff --git a/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts new file mode 100644 index 0000000000..f324ed11b1 --- /dev/null +++ b/extensions/machine-learning-services/src/test/modelManagement/deployedModelService.test.ts @@ -0,0 +1,410 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as utils from '../../common/utils'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as TypeMoq from 'typemoq'; +import * as should from 'should'; +import { Config } from '../../configurations/config'; +import { DeployedModelService } from '../../modelManagement/deployedModelService'; +import { QueryRunner } from '../../common/queryRunner'; +import { RegisteredModel } from '../../modelManagement/interfaces'; +import { ModelPythonClient } from '../../modelManagement/modelPythonClient'; +import * as path from 'path'; +import * as os from 'os'; +import * as UUID from 'vscode-languageclient/lib/utils/uuid'; +import * as fs from 'fs'; + +interface TestContext { + + apiWrapper: TypeMoq.IMock; + config: TypeMoq.IMock; + queryRunner: TypeMoq.IMock; + modelClient: TypeMoq.IMock; +} + +function createContext(): TestContext { + + return { + apiWrapper: TypeMoq.Mock.ofType(ApiWrapper), + config: TypeMoq.Mock.ofType(Config), + queryRunner: TypeMoq.Mock.ofType(QueryRunner), + modelClient: TypeMoq.Mock.ofType(ModelPythonClient) + }; +} + +describe('DeployedModelService', () => { + it('getDeployedModels should fail with no connection', async function (): Promise { + const testContext = createContext(); + let connection: azdata.connection.ConnectionProfile; + + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + await should(service.getDeployedModels()).rejected(); + }); + + it('getDeployedModels should returns models successfully', async function (): Promise { + const testContext = createContext(); + const connection = new azdata.connection.ConnectionProfile(); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + const expected: RegisteredModel[] = [ + { + id: 1, + artifactName: 'name1', + title: 'title1', + description: 'desc1', + created: '2018-01-01', + version: '1.1' + } + ]; + const result = { + rowCount: 1, + columnInfo: [], + rows: [ + [ + { + displayValue: '1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'name1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'title1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'desc1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '1.1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '2018-01-01', + isNull: false, + invariantCultureDisplayValue: '' + } + ] + ] + }; + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); + + testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); + testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table'); + const actual = await service.getDeployedModels(); + should.deepEqual(actual, expected); + }); + + it('loadModelParameters should load parameters using python client successfully', async function (): Promise { + const testContext = createContext(); + const expected = { + inputs: [ + { + 'name': 'p1', + 'type': 'int' + }, + { + 'name': 'p2', + 'type': 'varchar' + } + ], + outputs: [ + { + 'name': 'o1', + 'type': 'int' + }, + ] + }; + testContext.modelClient.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(expected)); + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + const actual = await service.loadModelParameters(''); + should.deepEqual(actual, expected); + }); + + it('downloadModel should download model successfully', async function (): Promise { + const testContext = createContext(); + const connection = new azdata.connection.ConnectionProfile(); + const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`); + await fs.promises.writeFile(tempFilePath, 'test'); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + const model: RegisteredModel = + { + id: 1, + artifactName: 'name1', + title: 'title1', + description: 'desc1', + created: '2018-01-01', + version: '1.1' + }; + const result = { + rowCount: 1, + columnInfo: [], + rows: [ + [ + { + displayValue: await utils.readFileInHex(tempFilePath), + isNull: false, + invariantCultureDisplayValue: '' + } + ] + ] + }; + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); + + testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); + testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table'); + testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + const actual = await service.downloadModel(model); + should.notEqual(actual, undefined); + }); + + it('deployLocalModel should returns models successfully', async function (): Promise { + const testContext = createContext(); + const connection = new azdata.connection.ConnectionProfile(); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + const model: RegisteredModel = + { + id: 1, + artifactName: 'name1', + title: 'title1', + description: 'desc1', + created: '2018-01-01', + version: '1.1' + }; + const row = [ + { + displayValue: '1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'name1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'title1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'desc1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '1.1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: '2018-01-01', + isNull: false, + invariantCultureDisplayValue: '' + } + ]; + const result = { + rowCount: 1, + columnInfo: [], + rows: [row] + }; + let updatedResult = { + rowCount: 1, + columnInfo: [], + rows: [row, row] + }; + let deployed = false; + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + testContext.modelClient.setup(x => x.deployModel(connection, '')).returns(() => { + deployed = true; + return Promise.resolve(); + }); + testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result); + }); + + testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db'); + testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table'); + testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + await should(service.deployLocalModel('', model)).resolved(); + }); + + it('getConfigureQuery should escape db name', async function (): Promise { + const testContext = createContext(); + const dbName = 'curre[n]tDb'; + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); + testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); + testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + const expected = ` + IF NOT EXISTS ( + SELECT [name] + FROM sys.databases + WHERE [name] = N'd[]b' + ) + CREATE DATABASE [d[[]]b] + GO + USE [d[[]]b] + IF EXISTS + ( SELECT [t.name], [s.name] + FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id + WHERE [t.name] = 'ta[b]le' + AND [s.name] = 'dbo' + ) + BEGIN + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name') + ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version') + ALTER TABLE [dbo].[ta[[b]]le] ADD [version] [varchar](256) NULL + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='created') + BEGIN + ALTER TABLE [dbo].[ta[[b]]le] ADD [created] [datetime] NULL + ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created + END + IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='description') + ALTER TABLE [dbo].[ta[[b]]le] ADD [description] [varchar](256) NULL + END + Else + BEGIN + CREATE TABLE [dbo].[ta[[b]]le]( + [artifact_id] [int] IDENTITY(1,1) NOT NULL, + [artifact_name] [varchar](256) NOT NULL, + [group_path] [varchar](256) NOT NULL, + [artifact_content] [varbinary](max) NOT NULL, + [artifact_initial_size] [bigint] NULL, + [name] [varchar](256) NULL, + [version] [varchar](256) NULL, + [created] [datetime] NULL, + [description] [varchar](256) NULL, + CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED + ( + [artifact_id] ASC + )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] + ) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY] + ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created] + END + `; + const actual = service.getConfigureQuery(dbName); + should.equal(actual.indexOf(expected) > 0, true); + }); + + it('getDeployedModelsQuery should escape db name', async function (): Promise { + const testContext = createContext(); + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); + testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); + testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + const expected = ` + SELECT artifact_id, artifact_name, name, description, version, created + FROM [d[[]]b].[dbo].[ta[[b]]le] + WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml' + Order by artifact_id + `; + const actual = service.getDeployedModelsQuery(); + should.deepEqual(expected, actual); + }); + + it('getUpdateModelQuery should escape db name', async function (): Promise { + const testContext = createContext(); + const dbName = 'curre[n]tDb'; + const model: RegisteredModel = + { + id: 1, + artifactName: 'name1', + title: 'title1', + description: 'desc1', + created: '2018-01-01', + version: '1.1' + }; + + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); + testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); + testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + const expected = ` + UPDATE [dbo].[ta[[b]]le] + SET + name = 'title1', + version = '1.1', + description = 'desc1' + WHERE artifact_id = 1`; + const actual = service.getUpdateModelQuery(dbName, model); + should.equal(actual.indexOf(expected) > 0, true); + //should.deepEqual(actual, expected); + + }); + + it('getModelContentQuery should escape db name', async function (): Promise { + const testContext = createContext(); + const model: RegisteredModel = + { + id: 1, + artifactName: 'name1', + title: 'title1', + description: 'desc1', + created: '2018-01-01', + version: '1.1' + }; + + let service = new DeployedModelService( + testContext.apiWrapper.object, + testContext.config.object, + testContext.queryRunner.object, + testContext.modelClient.object); + testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b'); + testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le'); + testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo'); + const expected = ` + SELECT artifact_content + FROM [d[[]]b].[dbo].[ta[[b]]le] + WHERE artifact_id = 1; + `; + const actual = service.getModelContentQuery(model); + should.deepEqual(actual, expected); + }); +}); diff --git a/extensions/machine-learning-services/src/test/modelManagement/modelPythonClient.test.ts b/extensions/machine-learning-services/src/test/modelManagement/modelPythonClient.test.ts new file mode 100644 index 0000000000..a2b985a016 --- /dev/null +++ b/extensions/machine-learning-services/src/test/modelManagement/modelPythonClient.test.ts @@ -0,0 +1,121 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as vscode from 'vscode'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as TypeMoq from 'typemoq'; +import * as should from 'should'; +import { Config } from '../../configurations/config'; + +import * as utils from '../utils'; +import { ProcessService } from '../../common/processService'; +import { PackageManager } from '../../packageManagement/packageManager'; +import { ModelPythonClient } from '../../modelManagement/modelPythonClient'; + +interface TestContext { + + apiWrapper: TypeMoq.IMock; + config: TypeMoq.IMock; + outputChannel: vscode.OutputChannel; + op: azdata.BackgroundOperation; + processService: TypeMoq.IMock; + packageManager: TypeMoq.IMock; +} + +function createContext(): TestContext { + const context = utils.createContext(); + + return { + apiWrapper: TypeMoq.Mock.ofType(ApiWrapper), + config: TypeMoq.Mock.ofType(Config), + outputChannel: context.outputChannel, + op: context.op, + processService: TypeMoq.Mock.ofType(ProcessService), + packageManager: TypeMoq.Mock.ofType(PackageManager) + }; +} + +describe('ModelPythonClient', () => { + it('deployModel should deploy the model successfully', async function (): Promise { + const testContext = createContext(); + const connection = new azdata.connection.ConnectionProfile(); + const modelPath = 'C:\\test'; + let service = new ModelPythonClient( + testContext.outputChannel, + testContext.apiWrapper.object, + testContext.processService.object, + testContext.config.object, + testContext.packageManager.object); + testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { + operationInfo.operation(testContext.op); + }); + testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath'); + testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), + TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('')); + + await service.deployModel(connection, modelPath); + }); + + it('loadModelParameters should load model parameters successfully', async function (): Promise { + const testContext = createContext(); + const modelPath = 'C:\\test'; + const expected = { + inputs: [ + { + 'name': 'p1', + 'type': 'int' + }, + { + 'name': 'p2', + 'type': 'varchar' + } + ], + outputs: [ + { + 'name': 'o1', + 'type': 'int' + }, + ] + }; + const parametersJson = ` + { + "inputs": [ + { + "name": "p1", + "type": "int" + }, + { + "name": "p2", + "type": "varchar" + } + ], + "outputs": [ + { + "name": "o1", + "type": "int" + } + ] + } + `; + let service = new ModelPythonClient( + testContext.outputChannel, + testContext.apiWrapper.object, + testContext.processService.object, + testContext.config.object, + testContext.packageManager.object); + testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath'); + testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), + TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson)); + testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { + operationInfo.operation(testContext.op); + }); + + const actual = await service.loadModelParameters(modelPath); + should.deepEqual(actual, expected); + }); +}); diff --git a/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts b/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts index 856135603a..0ee3592e29 100644 --- a/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts +++ b/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts @@ -354,7 +354,7 @@ describe('SQL Python Package Manager', () => { let provider = createProvider(testContext); let actual = await provider.getLocationTitle(); - should.deepEqual(actual, constants.packageManagerNoConnection); + should.deepEqual(actual, constants.noConnectionError); }); it('getLocationTitle Should return connection title string for valid connection', async function (): Promise { diff --git a/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts b/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts index cbe86dc4a7..6752a00dc1 100644 --- a/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts +++ b/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts @@ -279,7 +279,7 @@ describe('SQL R Package Manager', () => { let provider = createProvider(testContext); let actual = await provider.getLocationTitle(); - should.deepEqual(actual, constants.packageManagerNoConnection); + should.deepEqual(actual, constants.noConnectionError); }); it('getLocationTitle Should return connection title string for valid connection', async function (): Promise { diff --git a/extensions/machine-learning-services/src/test/packageManagement/utils.ts b/extensions/machine-learning-services/src/test/packageManagement/utils.ts index 3911af4db0..1957d38d8c 100644 --- a/extensions/machine-learning-services/src/test/packageManagement/utils.ts +++ b/extensions/machine-learning-services/src/test/packageManagement/utils.ts @@ -11,6 +11,7 @@ import { QueryRunner } from '../../common/queryRunner'; import { ProcessService } from '../../common/processService'; import { Config } from '../../configurations/config'; import { HttpClient } from '../../common/httpClient'; +import * as utils from '../utils'; import { PackageManagementService } from '../../packageManagement/packageManagementService'; export interface TestContext { @@ -27,31 +28,18 @@ export interface TestContext { } export function createContext(): TestContext { - let opStatus: azdata.TaskStatus; + const context = utils.createContext(); return { - outputChannel: { - name: '', - append: () => { }, - appendLine: () => { }, - clear: () => { }, - show: () => { }, - hide: () => { }, - dispose: () => { } - }, + + outputChannel: context.outputChannel, processService: TypeMoq.Mock.ofType(ProcessService), apiWrapper: TypeMoq.Mock.ofType(ApiWrapper), queryRunner: TypeMoq.Mock.ofType(QueryRunner), config: TypeMoq.Mock.ofType(Config), httpClient: TypeMoq.Mock.ofType(HttpClient), - op: { - updateStatus: (status: azdata.TaskStatus) => { - opStatus = status; - }, - id: '', - onCanceled: new vscode.EventEmitter().event, - }, - getOpStatus: () => { return opStatus; }, + op: context.op, + getOpStatus: context.getOpStatus, serverConfigManager: TypeMoq.Mock.ofType(PackageManagementService) }; } diff --git a/extensions/machine-learning-services/src/test/prediction/predictService.test.ts b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts new file mode 100644 index 0000000000..94b3a8cd84 --- /dev/null +++ b/extensions/machine-learning-services/src/test/prediction/predictService.test.ts @@ -0,0 +1,303 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as vscode from 'vscode'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as TypeMoq from 'typemoq'; +import * as should from 'should'; +import { Config } from '../../configurations/config'; +import { PredictService } from '../../prediction/predictService'; +import { QueryRunner } from '../../common/queryRunner'; +import { RegisteredModel } from '../../modelManagement/interfaces'; +import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; +import * as path from 'path'; +import * as os from 'os'; +import * as UUID from 'vscode-languageclient/lib/utils/uuid'; +import * as fs from 'fs'; + + +interface TestContext { + + apiWrapper: TypeMoq.IMock; + config: TypeMoq.IMock; + queryRunner: TypeMoq.IMock; +} + +function createContext(): TestContext { + + return { + apiWrapper: TypeMoq.Mock.ofType(ApiWrapper), + config: TypeMoq.Mock.ofType(Config), + queryRunner: TypeMoq.Mock.ofType(QueryRunner) + }; +} + +describe('PredictService', () => { + + it('getDatabaseList should return databases successfully', async function (): Promise { + const testContext = createContext(); + const expected: string[] = [ + 'db1', + 'db2' + ]; + const connection = new azdata.connection.ConnectionProfile(); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + testContext.apiWrapper.setup(x => x.listDatabases(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(expected); }); + + let service = new PredictService( + testContext.apiWrapper.object, + testContext.queryRunner.object, + testContext.config.object); + const actual = await service.getDatabaseList(); + should.deepEqual(actual, expected); + }); + + it('getTableList should return tables successfully', async function (): Promise { + const testContext = createContext(); + const expected: DatabaseTable[] = [ + { + databaseName: 'db1', + schema: 'dbo', + tableName: 'tb1' + }, + { + databaseName: 'db1', + tableName: 'tb2', + schema: 'dbo' + } + ]; + + const result = { + rowCount: 1, + columnInfo: [], + rows: [[ + { + displayValue: 'tb1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'dbo', + isNull: false, + invariantCultureDisplayValue: '' + } + ], [ + { + displayValue: 'tb2', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'dbo', + isNull: false, + invariantCultureDisplayValue: '' + } + ]] + }; + const connection = new azdata.connection.ConnectionProfile(); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); + let service = new PredictService( + testContext.apiWrapper.object, + testContext.queryRunner.object, + testContext.config.object); + const actual = await service.getTableList('db1'); + should.deepEqual(actual, expected); + }); + + it('getTableColumnsList should return table columns successfully', async function (): Promise { + const testContext = createContext(); + const expected: TableColumn[] = [ + { + columnName: 'c1', + dataType: 'int' + }, + { + columnName: 'c2', + dataType: 'varchar' + } + ]; + const table: DatabaseTable = + { + databaseName: 'db1', + schema: 'dbo', + tableName: 'tb1' + }; + + const result = { + rowCount: 1, + columnInfo: [], + rows: [[ + { + displayValue: 'c1', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'int', + isNull: false, + invariantCultureDisplayValue: '' + } + ], [ + { + displayValue: 'c2', + isNull: false, + invariantCultureDisplayValue: '' + }, + { + displayValue: 'varchar', + isNull: false, + invariantCultureDisplayValue: '' + } + ]] + }; + const connection = new azdata.connection.ConnectionProfile(); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + + testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result)); + let service = new PredictService( + testContext.apiWrapper.object, + testContext.queryRunner.object, + testContext.config.object); + const actual = await service.getTableColumnsList(table); + should.deepEqual(actual, expected); + }); + + it('generatePredictScript should generate the script successfully using model', async function (): Promise { + const testContext = createContext(); + const connection = new azdata.connection.ConnectionProfile(); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + const predictParams: PredictParameters = { + inputColumns: [ + { + paramName: 'p1', + dataType: 'int', + columnName: '' + }, + { + paramName: 'p2', + dataType: 'varchar', + columnName: '' + } + ], + outputColumns: [ + { + paramName: 'o1', + dataType: 'int', + columnName: '' + }, + ], + databaseName: '', + tableName: '', + schema: '' + }; + const model: RegisteredModel = + { + id: 1, + artifactName: 'name1', + title: 'title1', + description: 'desc1', + created: '2018-01-01', + version: '1.1' + }; + + let service = new PredictService( + testContext.apiWrapper.object, + testContext.queryRunner.object, + testContext.config.object); + + const document: vscode.TextDocument = { + uri: vscode.Uri.parse('file:///usr/home'), + fileName: '', + isUntitled: true, + languageId: 'sql', + version: 1, + isDirty: true, + isClosed: false, + save: undefined!, + eol: undefined!, + lineCount: 1, + lineAt: undefined!, + offsetAt: undefined!, + positionAt: undefined!, + getText: undefined!, + getWordRangeAtPosition: undefined!, + validateRange: undefined!, + validatePosition: undefined! + }; + testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document)); + testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { }); + + const actual = await service.generatePredictScript(predictParams, model, undefined); + should.notEqual(actual, undefined); + should.equal(actual.indexOf('FROM PREDICT(MODEL = @model') > 0, true); + }); + + it('generatePredictScript should generate the script successfully using file', async function (): Promise { + const testContext = createContext(); + const connection = new azdata.connection.ConnectionProfile(); + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + const predictParams: PredictParameters = { + inputColumns: [ + { + paramName: 'p1', + dataType: 'int', + columnName: '' + }, + { + paramName: 'p2', + dataType: 'varchar', + columnName: '' + } + ], + outputColumns: [ + { + paramName: 'o1', + dataType: 'int', + columnName: '' + }, + ], + databaseName: '', + tableName: '', + schema: '' + }; + const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`); + await fs.promises.writeFile(tempFilePath, 'test'); + + let service = new PredictService( + testContext.apiWrapper.object, + testContext.queryRunner.object, + testContext.config.object); + + const document: vscode.TextDocument = { + uri: vscode.Uri.parse('file:///usr/home'), + fileName: '', + isUntitled: true, + languageId: 'sql', + version: 1, + isDirty: true, + isClosed: false, + save: undefined!, + eol: undefined!, + lineCount: 1, + lineAt: undefined!, + offsetAt: undefined!, + positionAt: undefined!, + getText: undefined!, + getWordRangeAtPosition: undefined!, + validateRange: undefined!, + validatePosition: undefined! + }; + testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document)); + testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { }); + + const actual = await service.generatePredictScript(predictParams, undefined, tempFilePath); + should.notEqual(actual, undefined); + should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true); + }); +}); diff --git a/extensions/machine-learning-services/src/test/utils.ts b/extensions/machine-learning-services/src/test/utils.ts new file mode 100644 index 0000000000..420e35a506 --- /dev/null +++ b/extensions/machine-learning-services/src/test/utils.ts @@ -0,0 +1,38 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as vscode from 'vscode'; +import * as azdata from 'azdata'; + +export interface TestContext { + + outputChannel: vscode.OutputChannel; + op: azdata.BackgroundOperation; + getOpStatus: () => azdata.TaskStatus; +} + +export function createContext(): TestContext { + let opStatus: azdata.TaskStatus; + + return { + outputChannel: { + name: '', + append: () => { }, + appendLine: () => { }, + clear: () => { }, + show: () => { }, + hide: () => { }, + dispose: () => { } + }, + op: { + updateStatus: (status: azdata.TaskStatus) => { + opStatus = status; + }, + id: '', + onCanceled: new vscode.EventEmitter().event, + }, + getOpStatus: () => { return opStatus; } + }; +} diff --git a/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts new file mode 100644 index 0000000000..c5b40ceb11 --- /dev/null +++ b/extensions/machine-learning-services/src/test/views/models/predictWizard.test.ts @@ -0,0 +1,178 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as azdata from 'azdata'; +import * as should from 'should'; +import 'mocha'; +import { createContext } from './utils'; +import { + ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, + ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName +} + from '../../../views/models/modelViewBase'; +import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces'; +import { azureResource } from '../../../typings/azure-resource'; +import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; +import { ViewBase } from '../../../views/viewBase'; +import { WorkspaceModel } from '../../../modelManagement/interfaces'; +import { PredictWizard } from '../../../views/models/prediction/predictWizard'; +import { DatabaseTable, TableColumn } from '../../../prediction/interfaces'; + +describe('Predict Wizard', () => { + it('Should create view components successfully ', async function (): Promise { + let testContext = createContext(); + + let view = new PredictWizard(testContext.apiWrapper.object, ''); + await view.open(); + should.notEqual(view.wizardView, undefined); + should.notEqual(view.modelSourcePage, undefined); + }); + + it('Should load data successfully ', async function (): Promise { + let testContext = createContext(); + + let view = new PredictWizard(testContext.apiWrapper.object, ''); + await view.open(); + let accounts: azdata.Account[] = [ + { + key: { + accountId: '1', + providerId: '' + }, + displayInfo: { + displayName: 'account', + userId: '', + accountType: '', + contextualDisplayName: '' + }, + isStale: false, + properties: [] + } + ]; + let subscriptions: azureResource.AzureResourceSubscription[] = [ + { + name: 'subscription', + id: '2' + } + ]; + let groups: azureResource.AzureResourceResourceGroup[] = [ + { + name: 'group', + id: '3' + } + ]; + let workspaces: Workspace[] = [ + { + name: 'workspace', + id: '4' + } + ]; + let models: WorkspaceModel[] = [ + { + id: '5', + name: 'model' + } + ]; + let localModels: RegisteredModel[] = [ + { + id: 1, + artifactName: 'model', + title: 'model' + } + ]; + const dbNames: string[] = [ + 'db1', + 'db2' + ]; + const tableNames: DatabaseTable[] = [ + { + databaseName: 'db1', + schema: 'dbo', + tableName: 'tb1' + }, + { + databaseName: 'db1', + tableName: 'tb2', + schema: 'dbo' + } + ]; + const columnNames: TableColumn[] = [ + { + columnName: 'c1', + dataType: 'int' + }, + { + columnName: 'c2', + dataType: 'varchar' + } + ]; + const modelParameters: ModelParameters = { + inputs: [ + { + 'name': 'p1', + 'type': 'int' + }, + { + 'name': 'p2', + 'type': 'varchar' + } + ], + outputs: [ + { + 'name': 'o1', + 'type': 'int' + } + ] + }; + + view.on(ListModelsEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels }); + }); + view.on(ListAccountsEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts }); + }); + view.on(ListSubscriptionsEventName, () => { + + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions }); + }); + view.on(ListGroupsEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups }); + }); + view.on(ListWorkspacesEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces }); + }); + view.on(ListAzureModelsEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models }); + }); + view.on(ListDatabaseNamesEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames }); + }); + view.on(ListTableNamesEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames }); + }); + view.on(ListColumnNamesEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames }); + }); + view.on(LoadModelParametersEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters }); + }); + view.on(DownloadAzureModelEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' }); + }); + view.on(DownloadRegisteredModelEventName, () => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' }); + }); + await view.refresh(); + should.notEqual(view.azureModelsComponent?.data, undefined); + should.notEqual(view.localModelsComponent?.data, undefined); + + should.notEqual(await view.getModelFileName(), undefined); + await view.columnsSelectionPage?.onEnter(); + + should.notEqual(view.columnsSelectionPage?.data, undefined); + should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name); + should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length); + }); +}); diff --git a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts index 21c37a93ac..d4f542fc0c 100644 --- a/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts +++ b/extensions/machine-learning-services/src/test/views/models/registerModelWizard.test.ts @@ -20,7 +20,7 @@ describe('Register Model Wizard', () => { let testContext = createContext(); let view = new RegisterModelWizard(testContext.apiWrapper.object, ''); - view.open(); + await view.open(); await view.refresh(); should.notEqual(view.wizardView, undefined); should.notEqual(view.modelSourcePage, undefined); @@ -30,7 +30,7 @@ describe('Register Model Wizard', () => { let testContext = createContext(); let view = new RegisterModelWizard(testContext.apiWrapper.object, ''); - view.open(); + await view.open(); let accounts: azdata.Account[] = [ { key: { @@ -98,5 +98,7 @@ describe('Register Model Wizard', () => { view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models }); }); await view.refresh(); + should.notEqual(view.azureModelsComponent?.data ,undefined); + should.notEqual(view.localModelsComponent?.data, undefined); }); }); diff --git a/extensions/machine-learning-services/src/test/views/models/utils.ts b/extensions/machine-learning-services/src/test/views/models/utils.ts index 964ecd93e8..c65150a998 100644 --- a/extensions/machine-learning-services/src/test/views/models/utils.ts +++ b/extensions/machine-learning-services/src/test/views/models/utils.ts @@ -7,14 +7,12 @@ import * as azdata from 'azdata'; import * as vscode from 'vscode'; import * as TypeMoq from 'typemoq'; import { ApiWrapper } from '../../../common/apiWrapper'; -import * as mssql from '../../../../../mssql/src/mssql'; import { createViewContext } from '../utils'; import { ModelViewBase } from '../../../views/models/modelViewBase'; export interface TestContext { apiWrapper: TypeMoq.IMock; view: azdata.ModelView; - languageExtensionService: mssql.ILanguageExtensionService; onClick: vscode.EventEmitter; } @@ -34,16 +32,10 @@ export class ParentDialog extends ModelViewBase { export function createContext(): TestContext { let viewTestContext = createViewContext(); - let languageExtensionService: mssql.ILanguageExtensionService = { - listLanguages: () => { return Promise.resolve([]); }, - deleteLanguage: () => { return Promise.resolve(); }, - updateLanguage: () => { return Promise.resolve(); } - }; return { apiWrapper: viewTestContext.apiWrapper, view: viewTestContext.view, - languageExtensionService: languageExtensionService, onClick: viewTestContext.onClick }; } diff --git a/extensions/machine-learning-services/src/test/views/utils.ts b/extensions/machine-learning-services/src/test/views/utils.ts index 3ae57623f4..7f36593597 100644 --- a/extensions/machine-learning-services/src/test/views/utils.ts +++ b/extensions/machine-learning-services/src/test/views/utils.ts @@ -62,6 +62,9 @@ export function createViewContext(): ViewTestContext { onTextChanged: undefined!, onEnterKeyPressed: undefined!, value: '' + }); + let image: () => azdata.ImageComponent = () => Object.assign({}, componentBase, { + }); let dropdown: () => azdata.DropDownComponent = () => Object.assign({}, componentBase, { onValueChanged: onClick.event, @@ -124,6 +127,14 @@ export function createViewContext(): ViewTestContext { withProperties: () => inputBoxBuilder, withValidation: () => inputBoxBuilder }; + let imageBuilder: azdata.ComponentBuilder = { + component: () => { + let r = image(); + return r; + }, + withProperties: () => imageBuilder, + withValidation: () => imageBuilder + }; let dropdownBuilder: azdata.ComponentBuilder = { component: () => { let r = dropdown(); @@ -156,7 +167,7 @@ export function createViewContext(): ViewTestContext { editor: undefined!, diffeditor: undefined!, text: () => inputBoxBuilder, - image: undefined!, + image: () => imageBuilder, button: () => buttonBuilder, dropDown: () => dropdownBuilder, tree: undefined!, @@ -181,7 +192,7 @@ export function createViewContext(): ViewTestContext { try { await handler(view); } catch (err) { - console.log(err); + throw err; } }, onValidityChanged: undefined!, @@ -242,7 +253,13 @@ export function createViewContext(): ViewTestContext { enabled: true, description: '', onValidityChanged: onClick.event, - registerContent: () => { }, + registerContent: async (handler) => { + try { + await handler(view); + } catch (err) { + throw err; + } + }, modelView: undefined!, valid: true }; diff --git a/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts b/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts index c652fcb507..9e3cc04550 100644 --- a/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts +++ b/extensions/machine-learning-services/src/views/externalLanguages/languageViewBase.ts @@ -107,14 +107,14 @@ export abstract class LanguageViewBase { if (connection) { return `${connection.serverName} ${connection.databaseName ? connection.databaseName : constants.extLangLocal}`; } - return constants.packageManagerNoConnection; + return constants.noConnectionError; } public getServerTitle(): string { if (this.connection) { return this.connection.serverName; } - return constants.packageManagerNoConnection; + return constants.noConnectionError; } private async getCurrentConnectionUrl(): Promise { diff --git a/extensions/machine-learning-services/src/views/interfaces.ts b/extensions/machine-learning-services/src/views/interfaces.ts index 028a656e16..bae2be5bd4 100644 --- a/extensions/machine-learning-services/src/views/interfaces.ts +++ b/extensions/machine-learning-services/src/views/interfaces.ts @@ -18,6 +18,7 @@ export interface IPageView { onLeave?: () => Promise; validate?: () => Promise; refresh: () => Promise; + disposePage?: () => Promise; viewPanel: azdata.window.ModelViewPanel | undefined; title: string; } diff --git a/extensions/machine-learning-services/src/views/mainViewBase.ts b/extensions/machine-learning-services/src/views/mainViewBase.ts index fa3e5b342e..ec996d6d30 100644 --- a/extensions/machine-learning-services/src/views/mainViewBase.ts +++ b/extensions/machine-learning-services/src/views/mainViewBase.ts @@ -37,6 +37,16 @@ export class MainViewBase { } } + public async disposePages(): Promise { + if (this._pages) { + await Promise.all(this._pages.map(async (p) => { + if (p.disposePage) { + await p.disposePage(); + } + })); + } + } + public async refresh(): Promise { if (this._pages) { await Promise.all(this._pages.map(async (p) => await p.refresh())); diff --git a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts index d04c1d0709..3dde5de665 100644 --- a/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/azureModelsComponent.ts @@ -9,6 +9,7 @@ import { ApiWrapper } from '../../common/apiWrapper'; import { AzureResourceFilterComponent } from './azureResourceFilterComponent'; import { AzureModelsTable } from './azureModelsTable'; import { IDataComponent, AzureModelResource } from '../interfaces'; +import { ModelArtifact } from './prediction/modelArtifact'; export class AzureModelsComponent extends ModelViewBase implements IDataComponent { @@ -17,6 +18,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen private _loader: azdata.LoadingComponent | undefined; private _form: azdata.FormContainer | undefined; + private _downloadedFile: ModelArtifact | undefined; /** * Component to render a view to pick an azure model @@ -37,8 +39,14 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen .withProperties({ loading: true }).component(); + this.azureModelsTable.onModelSelectionChanged(async () => { + if (this._downloadedFile) { + await this._downloadedFile.close(); + } + this._downloadedFile = undefined; + }); - this.azureFilterComponent.onWorkspacesSelected(async () => { + this.azureFilterComponent.onWorkspacesSelectedChanged(async () => { await this.onLoading(); await this.azureModelsTable?.loadData(this.azureFilterComponent?.data); await this.onLoaded(); @@ -107,6 +115,22 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen }); } + public async getDownloadedModel(): Promise { + if (!this._downloadedFile) { + this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(this.data)); + } + return this._downloadedFile; + } + + /** + * disposes the view + */ + public async disposeComponent(): Promise { + if (this._downloadedFile) { + await this._downloadedFile.close(); + } + } + /** * Refreshes the view */ diff --git a/extensions/machine-learning-services/src/views/models/azureModelsTable.ts b/extensions/machine-learning-services/src/views/models/azureModelsTable.ts index 37c3caea51..ea128c7cf7 100644 --- a/extensions/machine-learning-services/src/views/models/azureModelsTable.ts +++ b/extensions/machine-learning-services/src/views/models/azureModelsTable.ts @@ -4,6 +4,7 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; +import * as vscode from 'vscode'; import * as constants from '../../common/constants'; import { ModelViewBase } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; @@ -18,6 +19,8 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent = new vscode.EventEmitter(); + public readonly onModelSelectionChanged: vscode.Event = this._onModelSelectionChanged.event; /** * Creates a view to render azure models in a table @@ -115,6 +118,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent { this._selectedModelId = model.id; + this._onModelSelectionChanged.fire(); }); return [model.name, model.createdTime, model.frameworkVersion, selectModelButton]; } diff --git a/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts b/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts index 43b9fba599..ad36ab8a33 100644 --- a/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts +++ b/extensions/machine-learning-services/src/views/models/azureResourceFilterComponent.ts @@ -27,8 +27,8 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData private _azureSubscriptions: azureResource.AzureResourceSubscription[] = []; private _azureGroups: azureResource.AzureResource[] = []; private _azureWorkspaces: Workspace[] = []; - private _onWorkspacesSelected: vscode.EventEmitter = new vscode.EventEmitter(); - public readonly onWorkspacesSelected: vscode.Event = this._onWorkspacesSelected.event; + private _onWorkspacesSelectedChanged: vscode.EventEmitter = new vscode.EventEmitter(); + public readonly onWorkspacesSelectedChanged: vscode.Event = this._onWorkspacesSelectedChanged.event; /** * Creates a new view @@ -59,7 +59,7 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData await this.onGroupSelected(); }); this._workspaces.onValueChanged(async () => { - await this.onWorkspaceSelected(); + await this.onWorkspaceSelectedChanged(); }); this._form = this._modelBuilder.formContainer().withFormItems([{ @@ -182,26 +182,26 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData this._workspaces.values = values; this._workspaces.value = values[0]; } - this.onWorkspaceSelected(); + this.onWorkspaceSelectedChanged(); } - private onWorkspaceSelected(): void { - this._onWorkspacesSelected.fire(); + private onWorkspaceSelectedChanged(): void { + this._onWorkspacesSelectedChanged.fire(); } private get workspace(): Workspace | undefined { - return this._azureWorkspaces ? this._azureWorkspaces.find(a => a.id === (this._workspaces.value).name) : undefined; + return this._azureWorkspaces && this._workspaces.value ? this._azureWorkspaces.find(a => a.id === (this._workspaces.value).name) : undefined; } private get account(): azdata.Account | undefined { - return this._azureAccounts ? this._azureAccounts.find(a => a.key.accountId === (this._accounts.value).name) : undefined; + return this._azureAccounts && this._accounts.value ? this._azureAccounts.find(a => a.key.accountId === (this._accounts.value).name) : undefined; } private get group(): azureResource.AzureResource | undefined { - return this._azureGroups ? this._azureGroups.find(a => a.id === (this._groups.value).name) : undefined; + return this._azureGroups && this._groups.value ? this._azureGroups.find(a => a.id === (this._groups.value).name) : undefined; } private get subscription(): azureResource.AzureResourceSubscription | undefined { - return this._azureSubscriptions ? this._azureSubscriptions.find(a => a.id === (this._subscriptions.value).name) : undefined; + return this._azureSubscriptions && this._subscriptions.value ? this._azureSubscriptions.find(a => a.id === (this._subscriptions.value).name) : undefined; } } diff --git a/extensions/machine-learning-services/src/views/models/modelManagementController.ts b/extensions/machine-learning-services/src/views/models/modelManagementController.ts index 5076d56b37..97c4bcfe71 100644 --- a/extensions/machine-learning-services/src/views/models/modelManagementController.ts +++ b/extensions/machine-learning-services/src/views/models/modelManagementController.ts @@ -9,15 +9,15 @@ import { azureResource } from '../../typings/azure-resource'; import { ApiWrapper } from '../../common/apiWrapper'; import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; -import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces'; -import { PredictParameters, DatabaseTable } from '../../prediction/interfaces'; -import { RegisteredModelService } from '../../modelManagement/registeredModelService'; +import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces'; +import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; +import { DeployedModelService } from '../../modelManagement/deployedModelService'; import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog'; import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName, - ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs + ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName } from './modelViewBase'; import { ControllerBase } from '../controllerBase'; import { RegisterModelWizard } from './registerModels/registerModelWizard'; @@ -39,7 +39,7 @@ export class ModelManagementController extends ControllerBase { apiWrapper: ApiWrapper, private _root: string, private _amlService: AzureModelRegistryService, - private _registeredModelService: RegisteredModelService, + private _registeredModelService: DeployedModelService, private _predictService: PredictService) { super(apiWrapper); } @@ -61,7 +61,7 @@ export class ModelManagementController extends ControllerBase { // Open view // - view.open(); + await view.open(); await view.refresh(); return view; } @@ -74,10 +74,15 @@ export class ModelManagementController extends ControllerBase { let view = new PredictWizard(this._apiWrapper, this._root); this.registerEvents(view); + view.on(LoadModelParametersEventName, async () => { + const modelArtifact = await view.getModelFileName(); + await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService, + modelArtifact?.filePath); + }); // Open view // - view.open(); + await view.open(); await view.refresh(); return view; } @@ -151,6 +156,11 @@ export class ModelManagementController extends ControllerBase { await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService, predictArgs, predictArgs.model, predictArgs.filePath); }); + view.on(DownloadRegisteredModelEventName, async (arg) => { + let model = arg; + await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService, + model); + }); view.on(SourceModelSelectedEventName, () => { view.refresh(); }); @@ -191,8 +201,8 @@ export class ModelManagementController extends ControllerBase { return await service.getWorkspaces(account, subscription, group); } - private async getRegisteredModels(registeredModelService: RegisteredModelService): Promise { - return registeredModelService.getRegisteredModels(); + private async getRegisteredModels(registeredModelService: DeployedModelService): Promise { + return registeredModelService.getDeployedModels(); } private async getAzureModels( @@ -207,9 +217,9 @@ export class ModelManagementController extends ControllerBase { return await service.getModels(account, subscription, resourceGroup, workspace) || []; } - private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise { + private async registerLocalModel(service: DeployedModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise { if (filePath) { - await service.registerLocalModel(filePath, details); + await service.deployLocalModel(filePath, details); } else { throw Error(constants.invalidModelToRegisterError); @@ -218,7 +228,7 @@ export class ModelManagementController extends ControllerBase { private async registerAzureModel( azureService: AzureModelRegistryService, - service: RegisteredModelService, + service: DeployedModelService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, resourceGroup: azureResource.AzureResource | undefined, @@ -231,7 +241,7 @@ export class ModelManagementController extends ControllerBase { const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model); if (filePath) { - await service.registerLocalModel(filePath, details); + await service.deployLocalModel(filePath, details); await fs.promises.unlink(filePath); } else { throw Error(constants.invalidModelToRegisterError); @@ -246,7 +256,7 @@ export class ModelManagementController extends ControllerBase { return await predictService.getTableList(databaseName); } - public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise { + public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise { return await predictService.getTableColumnsList(databaseTable); } @@ -263,6 +273,24 @@ export class ModelManagementController extends ControllerBase { return result; } + private async downloadRegisteredModel( + registeredModelService: DeployedModelService, + model: RegisteredModel | undefined): Promise { + if (!model) { + throw Error(constants.invalidModelToPredictError); + } + return await registeredModelService.downloadModel(model); + } + + private async loadModelParameters( + registeredModelService: DeployedModelService, + model: string | undefined): Promise { + if (!model) { + return undefined; + } + return await registeredModelService.loadModelParameters(model); + } + private async downloadAzureModel( azureService: AzureModelRegistryService, account: azdata.Account | undefined, diff --git a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts index a3f521f4f2..c8b046b5c7 100644 --- a/extensions/machine-learning-services/src/views/models/modelSourcePage.ts +++ b/extensions/machine-learning-services/src/views/models/modelSourcePage.ts @@ -120,4 +120,14 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo } return Promise.resolve(validated); } + + public async disposePage(): Promise { + if (this.azureModelsComponent) { + await this.azureModelsComponent.disposeComponent(); + + } + if (this.registeredModelsComponent) { + await this.registeredModelsComponent.disposeComponent(); + } + } } diff --git a/extensions/machine-learning-services/src/views/models/modelViewBase.ts b/extensions/machine-learning-services/src/views/models/modelViewBase.ts index 6b1b39da06..5b69ecc19b 100644 --- a/extensions/machine-learning-services/src/views/models/modelViewBase.ts +++ b/extensions/machine-learning-services/src/views/models/modelViewBase.ts @@ -8,8 +8,8 @@ import * as azdata from 'azdata'; import { azureResource } from '../../typings/azure-resource'; import { ApiWrapper } from '../../common/apiWrapper'; import { ViewBase } from '../viewBase'; -import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces'; -import { PredictParameters, DatabaseTable } from '../../prediction/interfaces'; +import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces'; +import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { AzureWorkspaceResource, AzureModelResource } from '../interfaces'; @@ -47,9 +47,11 @@ export const ListWorkspacesEventName = 'listWorkspaces'; export const RegisterLocalModelEventName = 'registerLocalModel'; export const RegisterAzureModelEventName = 'registerAzureLocalModel'; export const DownloadAzureModelEventName = 'downloadAzureLocalModel'; +export const DownloadRegisteredModelEventName = 'downloadRegisteredModel'; export const PredictModelEventName = 'predictModel'; export const RegisterModelEventName = 'registerModel'; export const SourceModelSelectedEventName = 'sourceModelSelected'; +export const LoadModelParametersEventName = 'loadModelParameters'; /** * Base class for all model management views @@ -75,7 +77,9 @@ export abstract class ModelViewBase extends ViewBase { ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, - DownloadAzureModelEventName]); + DownloadAzureModelEventName, + DownloadRegisteredModelEventName, + LoadModelParametersEventName]); } /** @@ -124,7 +128,7 @@ export abstract class ModelViewBase extends ViewBase { /** * lists column names */ - public async listColumnNames(table: DatabaseTable): Promise { + public async listColumnNames(table: DatabaseTable): Promise { return await this.sendDataRequest(ListColumnNamesEventName, table); } @@ -151,6 +155,14 @@ export abstract class ModelViewBase extends ViewBase { return await this.sendDataRequest(RegisterLocalModelEventName, args); } + /** + * downloads registered model + * @param model model to download + */ + public async downloadRegisteredModel(model: RegisteredModel | undefined): Promise { + return await this.sendDataRequest(DownloadRegisteredModelEventName, model); + } + /** * download azure model * @param args azure resource @@ -159,6 +171,13 @@ export abstract class ModelViewBase extends ViewBase { return await this.sendDataRequest(DownloadAzureModelEventName, resource); } + /** + * Loads model parameters + */ + public async loadModelParameters(): Promise { + return await this.sendDataRequest(LoadModelParametersEventName); + } + /** * registers azure model * @param args azure resource diff --git a/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts b/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts index f0c622442d..6eb45172df 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/columnsSelectionPage.ts @@ -8,7 +8,7 @@ import { ModelViewBase } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; import * as constants from '../../../common/constants'; import { IPageView, IDataComponent } from '../../interfaces'; -import { ColumnsFilterComponent } from './columnsFilterComponent'; +import { InputColumnsComponent } from './inputColumnsComponent'; import { OutputColumnsComponent } from './outputColumnsComponent'; import { PredictParameters } from '../../../prediction/interfaces'; @@ -19,7 +19,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID private _form: azdata.FormContainer | undefined; private _formBuilder: azdata.FormBuilder | undefined; - public columnsFilterComponent: ColumnsFilterComponent | undefined; + public inputColumnsComponent: InputColumnsComponent | undefined; public outputColumnsComponent: OutputColumnsComponent | undefined; constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { @@ -32,15 +32,14 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this._formBuilder = modelBuilder.formContainer(); - this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this); - this.columnsFilterComponent.registerComponent(modelBuilder); - this.columnsFilterComponent.addComponents(this._formBuilder); - this.refresh(); + this.inputColumnsComponent = new InputColumnsComponent(this._apiWrapper, this); + this.inputColumnsComponent.registerComponent(modelBuilder); + this.inputColumnsComponent.addComponents(this._formBuilder); this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this); this.outputColumnsComponent.registerComponent(modelBuilder); this.outputColumnsComponent.addComponents(this._formBuilder); - this.refresh(); + this._form = this._formBuilder.component(); return this._form; } @@ -49,8 +48,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID * Returns selected data */ public get data(): PredictParameters | undefined { - return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ? - Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) : + return this.inputColumnsComponent?.data && this.outputColumnsComponent?.data ? + Object.assign({}, this.inputColumnsComponent.data, { outputColumns: this.outputColumnsComponent.data }) : undefined; } @@ -66,8 +65,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID */ public async refresh(): Promise { if (this._formBuilder) { - if (this.columnsFilterComponent) { - await this.columnsFilterComponent.refresh(); + if (this.inputColumnsComponent) { + await this.inputColumnsComponent.refresh(); } if (this.outputColumnsComponent) { await this.outputColumnsComponent.refresh(); @@ -75,6 +74,24 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID } } + public async onEnter(): Promise { + await this.inputColumnsComponent?.onLoading(); + await this.outputColumnsComponent?.onLoading(); + try { + const modelParameters = await this.loadModelParameters(); + if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) { + this.inputColumnsComponent.modelParameters = modelParameters; + this.outputColumnsComponent.modelParameters = modelParameters; + await this.inputColumnsComponent.refresh(); + await this.outputColumnsComponent.refresh(); + } + } catch (error) { + this.showErrorMessage(constants.loadModelParameterFailedError, error); + } + await this.inputColumnsComponent?.onLoaded(); + await this.outputColumnsComponent?.onLoaded(); + } + /** * Returns page title */ diff --git a/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts index 7230f7f63a..717cae893f 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/columnsTable.ts @@ -8,133 +8,280 @@ import * as constants from '../../../common/constants'; import { ModelViewBase } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; import { IDataComponent } from '../../interfaces'; -import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces'; +import { PredictColumn, DatabaseTable, TableColumn } from '../../../prediction/interfaces'; +import { ModelParameter, ModelParameters } from '../../../modelManagement/interfaces'; /** * View to render azure models in a table */ export class ColumnsTable extends ModelViewBase implements IDataComponent { - private _table: azdata.DeclarativeTableComponent; - private _selectedColumns: PredictColumn[] = []; - private _columns: string[] | undefined; + private _table: azdata.DeclarativeTableComponent | undefined; + private _parameters: PredictColumn[] = []; + private _loader: azdata.LoadingComponent; + private _dataTypes: string[] = [ + 'bigint', + 'int', + 'smallint', + 'real', + 'float', + 'varchar(MAX)', + 'bit' + ]; + /** * Creates a view to render azure models in a table */ - constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) { + constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _forInput: boolean = true) { super(apiWrapper, parent.root, parent); - this._table = this.registerComponent(this._modelBuilder); + this._loader = this.registerComponent(this._modelBuilder); } /** * Register components * @param modelBuilder model builder */ - public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent { + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.LoadingComponent { + let columnHeader: azdata.DeclarativeTableColumn[]; + if (this._forInput) { + columnHeader = [ + { // Action + displayName: constants.columnName, + ariaLabel: constants.columnName, + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Name + displayName: '', + ariaLabel: '', + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Name + displayName: constants.inputName, + ariaLabel: constants.inputName, + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 120, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + } + ]; + } else { + columnHeader = [ + { // Name + displayName: constants.outputName, + ariaLabel: constants.outputName, + valueType: azdata.DeclarativeDataType.string, + isReadOnly: true, + width: 200, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Action + displayName: constants.displayName, + ariaLabel: constants.displayName, + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Action + displayName: constants.dataTypeName, + ariaLabel: constants.dataTypeName, + valueType: azdata.DeclarativeDataType.component, + isReadOnly: true, + width: 50, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + } + ]; + } this._table = modelBuilder.declarativeTable() + .withProperties( { - columns: [ - { // Name - displayName: constants.columnDatabase, - ariaLabel: constants.columnName, - valueType: azdata.DeclarativeDataType.string, - isReadOnly: true, - width: 120, - headerCssStyles: { - ...constants.cssStyles.tableHeader - }, - rowCssStyles: { - ...constants.cssStyles.tableRow - }, - }, - { // Action - displayName: constants.inputName, - ariaLabel: constants.inputName, - valueType: azdata.DeclarativeDataType.component, - isReadOnly: true, - width: 50, - headerCssStyles: { - ...constants.cssStyles.tableHeader - }, - rowCssStyles: { - ...constants.cssStyles.tableRow - }, - }, - { // Action - displayName: '', - valueType: azdata.DeclarativeDataType.component, - isReadOnly: true, - width: 50, - headerCssStyles: { - ...constants.cssStyles.tableHeader - }, - rowCssStyles: { - ...constants.cssStyles.tableRow - }, - } - ], + columns: columnHeader, data: [], ariaLabel: constants.mlsConfigTitle }) .component(); - return this._table; + this._loader = modelBuilder.loadingComponent() + .withItem(this._table) + .withProperties({ + loading: true + }).component(); + return this._loader; } - public get component(): azdata.DeclarativeTableComponent { - return this._table; + public async onLoading(): Promise { + if (this._loader) { + await this._loader.updateProperties({ loading: true }); + } + } + + public async onLoaded(): Promise { + if (this._loader) { + await this._loader.updateProperties({ loading: false }); + } + } + + public get component(): azdata.Component { + return this._loader; } /** * Load data in the component * @param workspaceResource Azure workspace */ - public async loadData(table: DatabaseTable): Promise { - this._selectedColumns = []; - if (this._table) { - this._columns = await this.listColumnNames(table); - let tableData: any[][] = []; + public async loadInputs(modelParameters: ModelParameters | undefined, table: DatabaseTable): Promise { + await this.onLoading(); + this._parameters = []; + let tableData: any[][] = []; - if (this._columns) { - tableData = tableData.concat(this._columns.map(model => this.createTableRow(model))); + if (this._table) { + if (this._forInput) { + const columns = await this.listColumnNames(table); + if (modelParameters?.inputs && columns) { + tableData = tableData.concat(modelParameters.inputs.map(input => this.createInputTableRow(input, columns))); + } } this._table.data = tableData; } + await this.onLoaded(); } - private createTableRow(column: string): any[] { - if (this._modelBuilder) { - let selectRowButton = this._modelBuilder.checkBox().withProperties({ + public async loadOutputs(modelParameters: ModelParameters | undefined): Promise { + this.onLoading(); + this._parameters = []; + let tableData: any[][] = []; - width: 15, - height: 15, - checked: true + if (this._table) { + if (!this._forInput) { + if (modelParameters?.outputs && this._dataTypes) { + tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes))); + } + } + + this._table.data = tableData; + } + this.onLoaded(); + } + + private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] { + if (this._modelBuilder) { + + let nameInput = this._modelBuilder.dropDown().withProperties({ + values: dataTypes, + width: this.componentMaxLength }).component(); - let nameInputBox = this._modelBuilder.inputBox().withProperties({ - value: '', - width: 150 - }).component(); - this._selectedColumns.push({ name: column }); - selectRowButton.onChanged(() => { - if (selectRowButton.checked) { - if (!this._selectedColumns.find(x => x.name === column)) { - this._selectedColumns.push({ name: column }); - } - } else { - if (this._selectedColumns.find(x => x.name === column)) { - this._selectedColumns = this._selectedColumns.filter(x => x.name !== column); + const name = modelParameter.name; + const dataType = dataTypes.find(x => x === modelParameter.type); + if (dataType) { + nameInput.value = dataType; + } + this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type }); + + nameInput.onValueChanged(() => { + const value = nameInput.value; + if (value !== modelParameter.type) { + let selectedRow = this._parameters.find(x => x.paramName === name); + if (selectedRow) { + selectedRow.dataType = value; } } }); - nameInputBox.onTextChanged(() => { - let selectedRow = this._selectedColumns.find(x => x.name === column); + let displayNameInput = this._modelBuilder.inputBox().withProperties({ + value: name, + width: 200 + }).component(); + displayNameInput.onTextChanged(() => { + let selectedRow = this._parameters.find(x => x.paramName === name); if (selectedRow) { - selectedRow.displayName = nameInputBox.value; + selectedRow.columnName = displayNameInput.value || name; } }); - return [column, nameInputBox, selectRowButton]; + return [`${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, displayNameInput, nameInput]; + } + + return []; + } + + private createInputTableRow(modelParameter: ModelParameter, columns: TableColumn[] | undefined): any[] { + if (this._modelBuilder && columns) { + const values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; }); + let nameInput = this._modelBuilder.dropDown().withProperties({ + values: values, + width: this.componentMaxLength + }).component(); + const name = modelParameter.name; + let column = values.find(x => x.name === modelParameter.name); + if (!column) { + column = values[0]; + } + nameInput.value = column; + + this._parameters.push({ columnName: column.name, paramName: name }); + + nameInput.onValueChanged(() => { + const selectedColumn = nameInput.value; + const value = selectedColumn ? (selectedColumn).name : undefined; + + let selectedRow = this._parameters.find(x => x.paramName === name); + if (selectedRow) { + selectedRow.columnName = value || ''; + } + }); + const label = this._modelBuilder.inputBox().withProperties({ + value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, + enabled: false, + width: this.componentMaxLength + }).component(); + const image = this._modelBuilder.image().withProperties({ + width: 50, + height: 50, + iconPath: { + dark: this.asAbsolutePath('images/arrow.svg'), + light: this.asAbsolutePath('images/arrow.svg') + }, + iconWidth: 20, + iconHeight: 20, + title: 'maps' + }).component(); + return [nameInput, image, label]; } return []; @@ -144,7 +291,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent { +export class InputColumnsComponent extends ModelViewBase implements IDataComponent { private _form: azdata.FormContainer | undefined; private _databases: azdata.DropDownComponent | undefined; @@ -22,7 +23,9 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon private _columns: ColumnsTable | undefined; private _dbNames: string[] = []; private _tableNames: DatabaseTable[] = []; - + private _modelParameters: ModelParameters | undefined; + private _dbTableComponent: azdata.FlexContainer | undefined; + private tableMaxLength = this.componentMaxLength * 2 + 70; /** * Creates a new view */ @@ -52,27 +55,47 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon }); - this._form = modelBuilder.formContainer().withFormItems([{ - title: constants.azureAccount, + const databaseForm = modelBuilder.formContainer().withFormItems([{ + title: constants.columnDatabase, component: this._databases - }, { - title: constants.azureSubscription, + }]).withLayout({ + padding: '0px' + }).component(); + const tableForm = modelBuilder.formContainer().withFormItems([{ + title: constants.columnTable, component: this._tables + }]).withLayout({ + padding: '0px' + }).component(); + this._dbTableComponent = modelBuilder.flexContainer().withItems([ + databaseForm, + tableForm + ], { + flex: '0 0 auto', + CSSStyles: { + 'align-items': 'flex-start' + } + }).withLayout({ + flexFlow: 'row', + justifyContent: 'space-between', + width: this.tableMaxLength + }).component(); + + this._form = modelBuilder.formContainer().withFormItems([{ + title: '', + component: this._dbTableComponent }, { - title: constants.azureGroup, + title: constants.inputColumns, component: this._columns.component }]).component(); return this._form; } public addComponents(formBuilder: azdata.FormBuilder) { - if (this._databases && this._tables && this._columns) { + if (this._columns && this._dbTableComponent) { formBuilder.addFormItems([{ - title: constants.columnDatabase, - component: this._databases - }, { - title: constants.columnTable, - component: this._tables + title: '', + component: this._dbTableComponent }, { title: constants.inputColumns, component: this._columns.component @@ -81,17 +104,13 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon } public removeComponents(formBuilder: azdata.FormBuilder) { - if (this._databases && this._tables && this._columns) { + if (this._columns && this._dbTableComponent) { formBuilder.removeFormItem({ - title: constants.azureAccount, - component: this._databases + title: '', + component: this._dbTableComponent }); formBuilder.removeFormItem({ - title: constants.azureSubscription, - component: this._tables - }); - formBuilder.removeFormItem({ - title: constants.azureGroup, + title: constants.inputColumns, component: this._columns.component }); } @@ -125,6 +144,22 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon await this.onDatabaseSelected(); } + public set modelParameters(value: ModelParameters) { + this._modelParameters = value; + } + + public async onLoading(): Promise { + if (this._columns) { + await this._columns.onLoading(); + } + } + + public async onLoaded(): Promise { + if (this._columns) { + await this._columns.onLoaded(); + } + } + /** * refreshes the view */ @@ -146,7 +181,7 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon } private async onTableSelected(): Promise { - this._columns?.loadData(this.databaseTable); + this._columns?.loadInputs(this._modelParameters, this.databaseTable); } private get databaseName(): string | undefined { diff --git a/extensions/machine-learning-services/src/views/models/prediction/modelArtifact.ts b/extensions/machine-learning-services/src/views/models/prediction/modelArtifact.ts new file mode 100644 index 0000000000..4f2501d249 --- /dev/null +++ b/extensions/machine-learning-services/src/views/models/prediction/modelArtifact.ts @@ -0,0 +1,35 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as utils from '../../../common/utils'; + +/** +* Wizard to register a model +*/ +export class ModelArtifact { + + /** + * Creates new model artifact + */ + constructor(private _filePath: string, private _deleteAtClose: boolean = true) { + } + + public get filePath(): string { + return this._filePath; + } + + /** + * Closes the artifact and disposes the resources + */ + public async close(): Promise { + if (this._deleteAtClose) { + try { + await utils.deleteFile(this._filePath); + } catch { + + } + } + } +} diff --git a/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts b/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts index 35782e9492..7dfbbb9684 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/outputColumnsComponent.ts @@ -9,25 +9,18 @@ import { ApiWrapper } from '../../../common/apiWrapper'; import * as constants from '../../../common/constants'; import { IDataComponent } from '../../interfaces'; import { PredictColumn } from '../../../prediction/interfaces'; +import { ColumnsTable } from './columnsTable'; +import { ModelParameters } from '../../../modelManagement/interfaces'; /** * View to render filters to pick an azure resource */ -const componentWidth = 60; + export class OutputColumnsComponent extends ModelViewBase implements IDataComponent { private _form: azdata.FormContainer | undefined; - private _flex: azdata.FlexContainer | undefined; - private _columnName: azdata.InputBoxComponent | undefined; - private _columnTypes: azdata.DropDownComponent | undefined; - private _dataTypes: string[] = [ - 'int', - 'nvarchar(MAX)', - 'varchar(MAX)', - 'float', - 'double', - 'bit' - ]; + private _columns: ColumnsTable | undefined; + private _modelParameters: ModelParameters | undefined; /** * Creates a new view @@ -41,49 +34,29 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon * @param modelBuilder model builder */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._columnName = modelBuilder.inputBox().withProperties({ - width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength - }).component(); - this._columnTypes = modelBuilder.dropDown().withProperties({ - width: componentWidth - }).component(); - - let flex = modelBuilder.flexContainer() - .withLayout({ - width: this._columnName.width - }).withItems([ - this._columnName] - ).component(); - this._flex = modelBuilder.flexContainer() - .withLayout({ - flexFlow: 'row', - justifyContent: 'space-between', - width: this.componentMaxLength - }).withItems([ - flex, this._columnTypes] - ).component(); + this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this, false); this._form = modelBuilder.formContainer().withFormItems([{ title: constants.azureAccount, - component: this._flex + component: this._columns.component }]).component(); return this._form; } public addComponents(formBuilder: azdata.FormBuilder) { - if (this._flex) { + if (this._columns) { formBuilder.addFormItems([{ title: constants.outputColumns, - component: this._flex + component: this._columns.component }]); } } public removeComponents(formBuilder: azdata.FormBuilder) { - if (this._flex) { + if (this._columns) { formBuilder.removeFormItem({ title: constants.outputColumns, - component: this._flex + component: this._columns.component }); } } @@ -99,9 +72,24 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon * loads data in the components */ public async loadData(): Promise { - if (this._columnTypes) { - this._columnTypes.values = this._dataTypes; - this._columnTypes.value = this._dataTypes[0]; + if (this._modelParameters) { + this._columns?.loadOutputs(this._modelParameters); + } + } + + public set modelParameters(value: ModelParameters) { + this._modelParameters = value; + } + + public async onLoading(): Promise { + if (this._columns) { + await this._columns.onLoading(); + } + } + + public async onLoaded(): Promise { + if (this._columns) { + await this._columns.onLoaded(); } } @@ -116,9 +104,6 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon * Returns selected data */ public get data(): PredictColumn[] | undefined { - return this._columnName && this._columnTypes ? [{ - name: this._columnName.value || '', - dataType: this._columnTypes.value || '' - }] : undefined; + return this._columns?.data; } } diff --git a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts index c7c3784f72..61d41bebf0 100644 --- a/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts +++ b/extensions/machine-learning-services/src/views/models/prediction/predictWizard.ts @@ -14,6 +14,7 @@ import { WizardView } from '../../wizardView'; import { ModelSourcePage } from '../modelSourcePage'; import { ColumnsSelectionPage } from './columnsSelectionPage'; import { RegisteredModel } from '../../../modelManagement/interfaces'; +import { ModelArtifact } from './modelArtifact'; /** * Wizard to register a model @@ -21,7 +22,6 @@ import { RegisteredModel } from '../../../modelManagement/interfaces'; export class PredictWizard extends ModelViewBase { public modelSourcePage: ModelSourcePage | undefined; - //public modelDetailsPage: ModelDetailsPage | undefined; public columnsSelectionPage: ColumnsSelectionPage | undefined; public wizardView: WizardView | undefined; private _parentView: ModelViewBase | undefined; @@ -37,7 +37,7 @@ export class PredictWizard extends ModelViewBase { /** * Opens a dialog to manage packages used by notebooks. */ - public open(): void { + public async open(): Promise { this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]); this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this); this.wizardView = new WizardView(this._apiWrapper); @@ -50,16 +50,22 @@ export class PredictWizard extends ModelViewBase { wizard.doneButton.label = constants.predictModel; wizard.generateScriptButton.hidden = true; wizard.displayPageTitles = true; + wizard.doneButton.onClick(async () => { + await this.onClose(); + }); + wizard.cancelButton.onClick(async () => { + await this.onClose(); + }); wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => { let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; - if (validated && pageInfo.newPage === undefined) { - wizard.cancelButton.enabled = false; - wizard.backButton.enabled = false; - await this.predict(); - wizard.cancelButton.enabled = true; - wizard.backButton.enabled = true; - if (this._parentView) { - this._parentView?.refresh(); + if (validated) { + if (pageInfo.newPage === undefined) { + this.onLoading(); + await this.predict(); + this.onLoaded(); + if (this._parentView) { + this._parentView?.refresh(); + } } return true; @@ -67,7 +73,22 @@ export class PredictWizard extends ModelViewBase { return validated; }); - wizard.open(); + await wizard.open(); + } + + private onLoading(): void { + this.refreshButtons(true); + } + + private onLoaded(): void { + this.refreshButtons(false); + } + + private refreshButtons(loading: boolean): void { + if (this.wizardView && this.wizardView.wizard) { + this.wizardView.wizard.cancelButton.enabled = !loading; + this.wizardView.wizard.cancelButton.enabled = !loading; + } } public get modelResources(): ModelSourcesComponent | undefined { @@ -82,16 +103,26 @@ export class PredictWizard extends ModelViewBase { return this.modelSourcePage?.azureModelsComponent; } + public async getModelFileName(): Promise { + if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) { + return new ModelArtifact(this.localModelsComponent.data, false); + } else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) { + return await this.azureModelsComponent.getDownloadedModel(); + } else if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) { + return await this.modelSourcePage.registeredModelsComponent.getDownloadedModel(); + } + return undefined; + } + private async predict(): Promise { try { - let modelFilePath: string = ''; + let modelFilePath: string | undefined; let registeredModel: RegisteredModel | undefined = undefined; - if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) { - modelFilePath = this.localModelsComponent.data; - } else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) { - modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data); - } else { + if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) { registeredModel = this.modelSourcePage?.registeredModelsComponent?.data; + } else { + const artifact = await this.getModelFileName(); + modelFilePath = artifact?.filePath; } await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data); @@ -102,6 +133,14 @@ export class PredictWizard extends ModelViewBase { } } + private async onClose(): Promise { + const artifact = await this.getModelFileName(); + if (artifact) { + artifact.close(); + } + await this.wizardView?.disposePages(); + } + /** * Refresh the pages */ diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts index b113858ef7..2723c82dc6 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsPage.ts @@ -15,7 +15,7 @@ import { IPageView } from '../../interfaces'; * View to render current registered models */ export class CurrentModelsPage extends ModelViewBase implements IPageView { - private _tableComponent: azdata.DeclarativeTableComponent | undefined; + private _tableComponent: azdata.Component | undefined; private _dataTable: CurrentModelsTable | undefined; private _loader: azdata.LoadingComponent | undefined; diff --git a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts index 34b91d488c..7efd3bdc8b 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/currentModelsTable.ts @@ -4,11 +4,13 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; +import * as vscode from 'vscode'; import * as constants from '../../../common/constants'; import { ModelViewBase } from '../modelViewBase'; import { ApiWrapper } from '../../../common/apiWrapper'; import { RegisteredModel } from '../../../modelManagement/interfaces'; import { IDataComponent } from '../../interfaces'; +import { ModelArtifact } from '../prediction/modelArtifact'; /** * View to render registered models table @@ -18,6 +20,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< private _table: azdata.DeclarativeTableComponent | undefined; private _modelBuilder: azdata.ModelBuilder | undefined; private _selectedModel: any; + private _loader: azdata.LoadingComponent | undefined; + private _downloadedFile: ModelArtifact | undefined; + private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter(); + public readonly onModelSelectionChanged: vscode.Event = this._onModelSelectionChanged.event; /** * Creates new view @@ -30,7 +36,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< * * @param modelBuilder register the components */ - public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent { + public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this._modelBuilder = modelBuilder; this._table = modelBuilder.declarativeTable() .withProperties( @@ -92,7 +98,12 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< ariaLabel: constants.mlsConfigTitle }) .component(); - return this._table; + this._loader = modelBuilder.loadingComponent() + .withItem(this._table) + .withProperties({ + loading: true + }).component(); + return this._loader; } public addComponents(formBuilder: azdata.FormBuilder) { @@ -111,14 +122,15 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< /** * Returns the component */ - public get component(): azdata.DeclarativeTableComponent | undefined { - return this._table; + public get component(): azdata.Component | undefined { + return this._loader; } /** * Loads the data in the component */ public async loadData(): Promise { + await this.onLoading(); if (this._table) { let models: RegisteredModel[] | undefined; @@ -131,6 +143,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< this._table.data = tableData; } + this.onModelSelected(); + await this.onLoaded(); + } + + public async onLoading(): Promise { + if (this._loader) { + await this._loader.updateProperties({ loading: true }); + } + } + + public async onLoaded(): Promise { + if (this._loader) { + await this._loader.updateProperties({ loading: false }); + } } private createTableRow(model: RegisteredModel): any[] { @@ -142,8 +168,9 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< height: 15, checked: false }).component(); - selectModelButton.onDidClick(() => { + selectModelButton.onDidClick(async () => { this._selectedModel = model; + await this.onModelSelected(); }); return [model.artifactName, model.title, model.created, selectModelButton]; } @@ -151,6 +178,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< return []; } + private async onModelSelected(): Promise { + this._onModelSelectionChanged.fire(); + if (this._downloadedFile) { + await this._downloadedFile.close(); + } + this._downloadedFile = undefined; + } + /** * Returns selected data */ @@ -158,6 +193,22 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< return this._selectedModel; } + public async getDownloadedModel(): Promise { + if (!this._downloadedFile) { + this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data)); + } + return this._downloadedFile; + } + + /** + * disposes the view + */ + public async disposeComponent(): Promise { + if (this._downloadedFile) { + await this._downloadedFile.close(); + } + } + /** * Refreshes the view */ diff --git a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts index 73013595f1..2d42bf00e7 100644 --- a/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts +++ b/extensions/machine-learning-services/src/views/models/registerModels/registerModelWizard.ts @@ -35,7 +35,7 @@ export class RegisterModelWizard extends ModelViewBase { /** * Opens a dialog to manage packages used by notebooks. */ - public open(): void { + public async open(): Promise { this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this); this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this); this.wizardView = new WizardView(this._apiWrapper); @@ -63,7 +63,7 @@ export class RegisterModelWizard extends ModelViewBase { return validated; }); - wizard.open(); + await wizard.open(); } public get modelResources(): ModelSourcesComponent | undefined { diff --git a/extensions/machine-learning-services/src/views/viewBase.ts b/extensions/machine-learning-services/src/views/viewBase.ts index 3702bab3bf..8cc36d9862 100644 --- a/extensions/machine-learning-services/src/views/viewBase.ts +++ b/extensions/machine-learning-services/src/views/viewBase.ts @@ -128,14 +128,14 @@ export abstract class ViewBase extends EventEmitterCollection { if (connection) { return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`; } - return constants.packageManagerNoConnection; + return constants.noConnectionError; } public getServerTitle(): string { if (this.connection) { return this.connection.serverName; } - return constants.packageManagerNoConnection; + return constants.noConnectionError; } private async getCurrentConnectionUrl(): Promise { diff --git a/extensions/machine-learning-services/src/views/wizardView.ts b/extensions/machine-learning-services/src/views/wizardView.ts index 33976ec0fb..49b1adef5a 100644 --- a/extensions/machine-learning-services/src/views/wizardView.ts +++ b/extensions/machine-learning-services/src/views/wizardView.ts @@ -68,7 +68,7 @@ export class WizardView extends MainViewBase { this._pages = pages; this._wizard.pages = pages.map(x => this.createWizardPage(x.title || '', x)); this._wizard.onPageChanged(async (info) => { - this.onWizardPageChanged(info); + await this.onWizardPageChanged(info); }); return this._wizard; @@ -85,17 +85,17 @@ export class WizardView extends MainViewBase { return true; } - private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) { + private async onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) { let idxLast = pageInfo.lastPage; let lastPage = this._pages[idxLast]; if (lastPage && lastPage.onLeave) { - lastPage.onLeave(); + await lastPage.onLeave(); } let idx = pageInfo.newPage; let page = this._pages[idx]; if (page && page.onEnter) { - page.onEnter(); + await page.onEnter(); } }