From 0258b1727a8ee9d291701e75f7a5b9673cf46dfe Mon Sep 17 00:00:00 2001 From: Leila Lali Date: Thu, 14 May 2020 12:46:47 -0700 Subject: [PATCH] Machine Learning - Bug fixes (#10377) * Fixing ML extension bugs --- extensions/machine-learning/package.json | 4 +- .../machine-learning/src/common/constants.ts | 10 +- .../machine-learning/src/common/utils.ts | 24 ++++ .../src/configurations/config.ts | 33 +++++- .../src/modelManagement/modelPythonClient.ts | 4 +- .../src/packageManagement/packageManager.ts | 25 ++-- .../sqlPythonPackageManageProvider.ts | 2 +- .../sqlRPackageManageProvider.ts | 2 +- .../src/prediction/predictService.ts | 12 +- .../src/test/common/config.test.ts | 110 +++++++++++++++++ .../modelManagement/modelPythonClient.test.ts | 4 +- .../packageManagement/packageManager.test.ts | 4 +- .../sqlPythonPackageManageProvider.test.ts | 2 +- .../sqlRPackageManageProvider.test.ts | 2 +- .../src/test/views/dashboardWidget.test.ts | 4 +- .../manageModels/currentModelsComponent.ts | 76 +++++++++--- .../models/manageModels/currentModelsTable.ts | 6 +- .../models/manageModels/importModelWizard.ts | 5 +- .../manageModels/modelImportLocationPage.ts | 61 +++++++--- .../views/models/prediction/columnsTable.ts | 2 +- .../views/models/tableSelectionComponent.ts | 112 ++++++++++-------- .../src/views/widgets/dashboardWidget.ts | 4 +- 22 files changed, 382 insertions(+), 126 deletions(-) create mode 100644 extensions/machine-learning/src/test/common/config.test.ts diff --git a/extensions/machine-learning/package.json b/extensions/machine-learning/package.json index 3a388bb5cc..544241b7a1 100644 --- a/extensions/machine-learning/package.json +++ b/extensions/machine-learning/package.json @@ -42,12 +42,12 @@ }, "machineLearningServices.pythonPath": { "type": "string", - "default": "python", + "default": "", "description": "%mls.pythonPath.description%" }, "machineLearningServices.rPath": { "type": "string", - "default": "r", + "default": "", "description": "%mls.rPath.description%" } } diff --git a/extensions/machine-learning/src/common/constants.ts b/extensions/machine-learning/src/common/constants.ts index a5a0690bac..3c6d1ba32a 100644 --- a/extensions/machine-learning/src/common/constants.ts +++ b/extensions/machine-learning/src/common/constants.ts @@ -41,6 +41,7 @@ export const pythonEnabledConfigKey = 'enablePython'; export const rEnabledConfigKey = 'enableR'; export const registeredModelsTableName = 'registeredModelsTableName'; export const rPathConfigKey = 'rPath'; +export const adsPythonBundleVersion = '0.0.1'; // Localized texts // @@ -48,8 +49,10 @@ export const msgYes = localize('msgYes', "Yes"); export const msgNo = localize('msgNo', "No"); export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed."); export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); } -export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Installing package management dependencies"); -export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Installing model management dependencies"); +export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); } +export function cannotFindR(path: string): string { return localize('mls.cannotFindR.error', "Cannot find R executable '{0}'. Please make sure R is installed and configured correctly", path); } +export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Verifying package management dependencies"); +export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Verifying model management dependencies"); export const noResultError = localize('mls.noResultError', "No Result returned"); export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed"); export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that."); @@ -122,6 +125,8 @@ export const extLangInstallFailedError = localize('extLang.installFailedError', export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language"); export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model"); +export const modelsListEmptyMessage = localize('models.modelsListEmptyMessage', "No Models Yet"); +export const modelsListEmptyDescription = localize('models.modelsListEmptyDescription', "Use import wizard to add models to this table"); export const databaseName = localize('databaseName', "Models database"); export const tableName = localize('tableName', "Models table"); export const existingTableName = localize('existingTableName', "Existing table"); @@ -195,6 +200,7 @@ export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMism export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!"); +export const selectModelsTableMessage = localize('models.selectModelsTableMessage', "Select models table"); export const modelSchemaIsNotAcceptedMessage = localize('models.modelSchemaIsNotAcceptedMessage', "Invalid table structure"); export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); } export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); } diff --git a/extensions/machine-learning/src/common/utils.ts b/extensions/machine-learning/src/common/utils.ts index f4113c715d..76bfe244c6 100644 --- a/extensions/machine-learning/src/common/utils.ts +++ b/extensions/machine-learning/src/common/utils.ts @@ -44,6 +44,15 @@ export async function exists(path: string): Promise { return promisify(fs.exists)(path); } +export async function isDirectory(path: string): Promise { + try { + const stat = await fs.promises.lstat(path); + return stat.isDirectory(); + } catch { + return false; + } +} + export async function createFolder(dirPath: string): Promise { let folderExists = await exists(dirPath); if (!folderExists) { @@ -259,3 +268,18 @@ export function getFileName(filePath: string) { return ''; } } + +export function getDefaultPythonLocation(): string { + + return path.join(getUserHome() || '', 'azuredatastudio-python', + constants.adsPythonBundleVersion, + getPythonExeName()); +} + +export function getPythonExeName(): string { + return process.platform === constants.winPlatform ? 'python.exe' : 'bin/python3'; +} + +export function getUserHome(): string | undefined { + return process.env.HOME || process.env.USERPROFILE; +} diff --git a/extensions/machine-learning/src/configurations/config.ts b/extensions/machine-learning/src/configurations/config.ts index 1da8625ef3..9171a62d71 100644 --- a/extensions/machine-learning/src/configurations/config.ts +++ b/extensions/machine-learning/src/configurations/config.ts @@ -9,10 +9,11 @@ import * as constants from '../common/constants'; import { promises as fs } from 'fs'; import * as path from 'path'; import { PackageConfigModel } from './packageConfigModel'; +import * as utils from '../common/utils'; const configFileName = 'config.json'; -const defaultPythonExecutable = 'python'; -const defaultRExecutable = 'r'; +const defaultPythonExecutable = ''; +const defaultRExecutable = ''; /** @@ -57,8 +58,22 @@ export class Config { /** * Returns python path from user settings */ - public get pythonExecutable(): string { - return this.config.get(constants.pythonPathConfigKey) || defaultPythonExecutable; + public async getPythonExecutable(verify: boolean): Promise { + let executable: string = this.config.get(constants.pythonPathConfigKey) || defaultPythonExecutable; + if (!executable) { + executable = utils.getDefaultPythonLocation(); + } else { + const exeName = utils.getPythonExeName(); + const isFolder = await utils.isDirectory(executable); + if (isFolder && executable.indexOf(exeName) < 0) { + executable = path.join(executable, exeName); + } + } + let checkExist = executable && executable.toLocaleUpperCase() !== 'PYTHON' && executable.toLocaleUpperCase() !== 'PYTHON3'; + if (verify && checkExist && !await utils.exists(executable)) { + throw new Error(constants.cannotFindPython(executable)); + } + return executable; } /** @@ -128,8 +143,14 @@ export class Config { /** * Returns r path from user settings */ - public get rExecutable(): string { - return this.config.get(constants.rPathConfigKey) || defaultRExecutable; + public async getRExecutable(verify: boolean): Promise { + let executable: string = this.config.get(constants.rPathConfigKey) || defaultRExecutable; + let checkExist = executable && executable.toLocaleUpperCase() !== 'R'; + if (verify && checkExist && !await utils.exists(executable)) { + throw new Error(constants.cannotFindR(executable)); + } + + return executable; } private get config(): vscode.WorkspaceConfiguration { diff --git a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts index 494070c1d8..bc5417256a 100644 --- a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts +++ b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts @@ -92,7 +92,7 @@ export class ModelPythonClient { 'addParameters(onnx_model.graph.output, "outputs")', 'print(json.dumps(parameters))' ]; - let pythonExecutable = this._config.pythonExecutable; + let pythonExecutable = await this._config.getPythonExecutable(true); let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined); let parametersJson = JSON.parse(output); return Object.assign({}, parametersJson); @@ -124,7 +124,7 @@ export class ModelPythonClient { 'mlflow.set_experiment(exp_name)', 'mlflow.onnx.log_model(onx, "pipeline_vectorize")' ]; - let pythonExecutable = this._config.pythonExecutable; + let pythonExecutable = await this._config.getPythonExecutable(true); await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel); } } diff --git a/extensions/machine-learning/src/packageManagement/packageManager.ts b/extensions/machine-learning/src/packageManagement/packageManager.ts index b949d2561b..98e98979fd 100644 --- a/extensions/machine-learning/src/packageManagement/packageManager.ts +++ b/extensions/machine-learning/src/packageManagement/packageManager.ts @@ -45,12 +45,12 @@ export class PackageManager { public init(): void { } - private get pythonExecutable(): string { - return this._config.pythonExecutable; + private async getPythonExecutable(): Promise { + return await this._config.getPythonExecutable(true); } - private get _rExecutable(): string { - return this._config.rExecutable; + private async getRExecutable(): Promise { + return await this._config.getRExecutable(true); } /** * Returns packageManageProviders @@ -123,7 +123,8 @@ export class PackageManager { if (!this._config.rEnabled) { return; } - if (!this._rExecutable) { + let rExecutable = await this.getRExecutable(); + if (!rExecutable) { throw new Error(constants.rConfigError); } @@ -139,7 +140,8 @@ export class PackageManager { if (!this._config.pythonEnabled) { return; } - if (!this.pythonExecutable) { + let pythonExecutable = await this.getPythonExecutable(); + if (!pythonExecutable) { throw new Error(constants.pythonConfigError); } if (!requiredPackages || requiredPackages.length === 0) { @@ -177,7 +179,8 @@ export class PackageManager { private async getInstalledPipPackages(): Promise { try { - let cmd = `"${this.pythonExecutable}" -m pip list --format=json`; + let pythonExecutable = await this.getPythonExecutable(); + let cmd = `"${pythonExecutable}" -m pip list --format=json`; let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined); let packagesResult: nbExtensionApis.IPackageDetails[] = []; if (packagesInfo && packagesInfo.indexOf(']') > 0) { @@ -196,23 +199,25 @@ export class PackageManager { } private async installPipPackage(requirementFilePath: string): Promise { - let cmd = `"${this.pythonExecutable}" -m pip install -r "${requirementFilePath}"`; + let pythonExecutable = await this.getPythonExecutable(); + let cmd = `"${pythonExecutable}" -m pip install -r "${requirementFilePath}"`; return await this._processService.executeBufferedCommand(cmd, this._outputChannel); } private async installRPackage(model: PackageConfigModel): Promise { let output = ''; let cmd = ''; + let rExecutable = await this.getRExecutable(); if (model.downloadUrl) { const packageFile = utils.getPackageFilePath(this._rootFolder, model.fileName || model.name); const packageExist = await utils.exists(packageFile); if (!packageExist) { await this._httpClient.download(model.downloadUrl, packageFile, this._outputChannel); } - cmd = `"${this._rExecutable}" CMD INSTALL ${packageFile}`; + cmd = `"${rExecutable}" CMD INSTALL ${packageFile}`; output = await this._processService.executeBufferedCommand(cmd, this._outputChannel); } else if (model.repository) { - cmd = `"${this._rExecutable}" -e "install.packages('${model.name}', repos='${model.repository}')"`; + cmd = `"${rExecutable}" -e "install.packages('${model.name}', repos='${model.repository}')"`; output = await this._processService.executeBufferedCommand(cmd, this._outputChannel); } return output; diff --git a/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts b/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts index 8639d701ab..642bbd0d92 100644 --- a/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts +++ b/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts @@ -84,7 +84,7 @@ export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase 'pkgmanager = sqlmlutils.SQLPackageManager(connection)', pythonCommandScript ]; - let pythonExecutable = this._config.pythonExecutable; + let pythonExecutable = await this._config.getPythonExecutable(true); await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel); } } diff --git a/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts b/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts index cc20d599ca..14548379e9 100644 --- a/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts +++ b/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts @@ -83,7 +83,7 @@ export class SqlRPackageManageProvider extends SqlPackageManageProviderBase impl `${rCommandScript}(connectionString = connection, pkgs, scope = "PUBLIC")`, 'q()' ]; - let rExecutable = this._config.rExecutable; + let rExecutable = await this._config.getRExecutable(true); await this._processService.execScripts(`${rExecutable}`, scripts, ['--vanilla'], this._outputChannel); } } diff --git a/extensions/machine-learning/src/prediction/predictService.ts b/extensions/machine-learning/src/prediction/predictService.ts index 7172bcf336..600d5ff972 100644 --- a/extensions/machine-learning/src/prediction/predictService.ts +++ b/extensions/machine-learning/src/prediction/predictService.ts @@ -176,7 +176,8 @@ AS ( FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi ) SELECT -${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getPredictInputColumnNames(outputColumns, 'p')} +${this.getPredictColumnNames(columns, 'predict_input')}, +${this.getPredictInputColumnNames(outputColumns, 'p')} FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx) WITH ( ${this.getOutputParameters(outputColumns)} @@ -197,7 +198,8 @@ AS ( FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi ) SELECT -${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')} +${this.getPredictColumnNames(columns, 'predict_input')}, +${this.getPredictInputColumnNames(outputColumns, 'p')} FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx) WITH ( ${this.getOutputParameters(outputColumns)} @@ -224,12 +226,6 @@ WITH ( }).join(',\n'); } - private getOutputColumnNames(columns: PredictColumn[], tableName: string) { - return columns.map(c => { - return this.getColumnName(tableName, c.columnName, c.paramName || ''); - }).join(',\n'); - } - private getColumnName(tableName: string, columnName: string, displayName: string) { const column = this.getEscapedColumnName(tableName, columnName); return columnName && columnName !== displayName ? diff --git a/extensions/machine-learning/src/test/common/config.test.ts b/extensions/machine-learning/src/test/common/config.test.ts new file mode 100644 index 0000000000..9781531011 --- /dev/null +++ b/extensions/machine-learning/src/test/common/config.test.ts @@ -0,0 +1,110 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as vscode from 'vscode'; +import { ApiWrapper } from '../../common/apiWrapper'; +import * as TypeMoq from 'typemoq'; +import * as should from 'should'; +import { Config } from '../../configurations/config'; +import * as utils from '../../common/utils'; +import * as path from 'path'; + +interface TestContext { + + apiWrapper: TypeMoq.IMock; +} + +function createContext(): TestContext { + return { + apiWrapper: TypeMoq.Mock.ofType(ApiWrapper) + }; +} + +let configData : vscode.WorkspaceConfiguration = { + get: () => {}, + has: () => true, + inspect: () => undefined, + update: () => {return Promise.resolve();}, + +}; + +describe('Config', () => { + it('getPythonExecutable should default to ADS python location is not configured', async function (): Promise { + const context = createContext(); + configData.get = () => { return ''; }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + const expected = utils.getDefaultPythonLocation(); + const actual = await config.getPythonExecutable(false); + should.deepEqual(actual, expected); + }); + + it('getPythonExecutable should add python executable name is folder path is configured', async function (): Promise { + const context = createContext(); + configData.get = () => { return utils.getUserHome(); }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + const expected = path.join(utils.getUserHome() || '', utils.getPythonExeName()); + const actual = await config.getPythonExecutable(false); + should.deepEqual(actual, expected); + }); + + it('getPythonExecutable should not add python executable if already added', async function (): Promise { + const context = createContext(); + configData.get = () => { return path.join(utils.getUserHome() || '', utils.getPythonExeName()); }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + const expected = path.join(utils.getUserHome() || '', utils.getPythonExeName()); + const actual = await config.getPythonExecutable(false); + should.deepEqual(actual, expected); + }); + + it('getPythonExecutable should not add python executable set to python', async function (): Promise { + const context = createContext(); + configData.get = () => { return 'python'; }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + const expected = 'python'; + const actual = await config.getPythonExecutable(false); + should.deepEqual(actual, expected); + }); + + it('getPythonExecutable should not add python executable set to python3', async function (): Promise { + const context = createContext(); + configData.get = () => { return 'python3'; }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + const expected = 'python3'; + const actual = await config.getPythonExecutable(false); + should.deepEqual(actual, expected); + }); + + it('getRExecutable should not add r executable set to r', async function (): Promise { + const context = createContext(); + configData.get = () => { return 'r'; }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + const expected = 'r'; + const actual = await config.getRExecutable(false); + should.deepEqual(actual, expected); + }); + + it('getPythonExecutable should throw error if file does not exist', async function (): Promise { + const context = createContext(); + configData.get = () => { return path.join(utils.getUserHome() || '', 'invalidPath'); }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + await should(config.getPythonExecutable(true)).be.rejected(); + }); + + it('getRExecutable should throw error if file does not exist', async function (): Promise { + const context = createContext(); + configData.get = () => { return path.join(utils.getUserHome() || '', 'invalidPath'); }; + context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData); + let config = new Config('', context.apiWrapper.object); + await should(config.getRExecutable(true)).be.rejected(); + }); + +}); diff --git a/extensions/machine-learning/src/test/modelManagement/modelPythonClient.test.ts b/extensions/machine-learning/src/test/modelManagement/modelPythonClient.test.ts index a2b985a016..222f0836f9 100644 --- a/extensions/machine-learning/src/test/modelManagement/modelPythonClient.test.ts +++ b/extensions/machine-learning/src/test/modelManagement/modelPythonClient.test.ts @@ -53,7 +53,7 @@ describe('ModelPythonClient', () => { testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { operationInfo.operation(testContext.op); }); - testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath'); + testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('pythonPath')); testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('')); @@ -108,7 +108,7 @@ describe('ModelPythonClient', () => { testContext.config.object, testContext.packageManager.object); testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve()); - testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath'); + testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('pythonPath')); testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson)); testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { diff --git a/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts b/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts index ca85657c92..66353dd73a 100644 --- a/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts +++ b/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts @@ -254,8 +254,8 @@ describe('Package Manager', () => { { name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'} ]); testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve()); - testContext.config.setup(x => x.pythonExecutable).returns(() => 'python'); - testContext.config.setup(x => x.rExecutable).returns(() => 'r'); + testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('python')); + testContext.config.setup(x => x.getRExecutable(true)).returns(() => Promise.resolve('r')); testContext.config.setup(x => x.rEnabled).returns(() => true); testContext.config.setup(x => x.pythonEnabled).returns(() => true); let packageManager = new PackageManager( diff --git a/extensions/machine-learning/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts b/extensions/machine-learning/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts index ceb2f50202..d8a38148c6 100644 --- a/extensions/machine-learning/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts +++ b/extensions/machine-learning/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts @@ -386,7 +386,7 @@ describe('SQL Python Package Manager', () => { }); function createProvider(testContext: TestContext): SqlPythonPackageManageProvider { - testContext.config.setup(x => x.pythonExecutable).returns(() => 'python'); + testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('python')); testContext.config.setup(x => x.pythonEnabled).returns(() => true); return new SqlPythonPackageManageProvider( testContext.outputChannel, diff --git a/extensions/machine-learning/src/test/packageManagement/sqlRPackageManageProvider.test.ts b/extensions/machine-learning/src/test/packageManagement/sqlRPackageManageProvider.test.ts index 4bf4146785..d46eb3b584 100644 --- a/extensions/machine-learning/src/test/packageManagement/sqlRPackageManageProvider.test.ts +++ b/extensions/machine-learning/src/test/packageManagement/sqlRPackageManageProvider.test.ts @@ -311,7 +311,7 @@ describe('SQL R Package Manager', () => { }); function createProvider(testContext: TestContext): SqlRPackageManageProvider { - testContext.config.setup(x => x.rExecutable).returns(() => 'r'); + testContext.config.setup(x => x.getRExecutable(true)).returns(() => Promise.resolve('r')); testContext.config.setup(x => x.rEnabled).returns(() => true); testContext.config.setup(x => x.rPackagesRepository).returns(() => 'http://cran.r-project.org'); return new SqlRPackageManageProvider( diff --git a/extensions/machine-learning/src/test/views/dashboardWidget.test.ts b/extensions/machine-learning/src/test/views/dashboardWidget.test.ts index 09390c68e0..377f70186d 100644 --- a/extensions/machine-learning/src/test/views/dashboardWidget.test.ts +++ b/extensions/machine-learning/src/test/views/dashboardWidget.test.ts @@ -39,10 +39,12 @@ describe('Dashboard widget', () => { await handler(testContext.view); }); + testContext.apiWrapper.setup(x => x.openExternal(TypeMoq.It.isAny())).returns(() => Promise.resolve(true)); + testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(true)); const dashboard = new DashboardWidget(testContext.apiWrapper.object, '', testContext.predictService.object); await dashboard.register(); testContext.onClick.fire(undefined); - testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce()); + testContext.apiWrapper.verify(x => x.openExternal(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce()); }); }); diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts index edb751114b..4b736638c5 100644 --- a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts +++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts @@ -17,10 +17,13 @@ import { ImportedModel } from '../../../modelManagement/interfaces'; * View to render current registered models */ export class CurrentModelsComponent extends ModelViewBase implements IPageView { - private _tableComponent: azdata.Component | undefined; private _dataTable: CurrentModelsTable | undefined; private _loader: azdata.LoadingComponent | undefined; private _tableSelectionComponent: TableSelectionComponent | undefined; + private _labelComponent: azdata.TextComponent | undefined; + private _descriptionComponent: azdata.TextComponent | undefined; + private _labelContainer: azdata.FlexContainer | undefined; + private _formBuilder: azdata.FormBuilder | undefined; /** * @@ -43,37 +46,69 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { }); this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings); this._dataTable.registerComponent(modelBuilder); - this._tableComponent = this._dataTable.component; let formModelBuilder = modelBuilder.formContainer(); - this._tableSelectionComponent.addComponents(formModelBuilder); - - if (this._tableComponent) { - formModelBuilder.addFormItem({ - component: this._tableComponent, - title: '' - }); - } - this._loader = modelBuilder.loadingComponent() .withItem(formModelBuilder.component()) .withProperties({ loading: true }).component(); + this._labelComponent = modelBuilder.text().withProperties({ + width: 200, + value: constants.modelsListEmptyMessage + }).component(); + this._descriptionComponent = modelBuilder.text().withProperties({ + width: 200, + value: constants.modelsListEmptyDescription + }).component(); + this._labelContainer = modelBuilder.flexContainer().withLayout({ + flexFlow: 'column', + width: 800, + height: '400px', + justifyContent: 'center' + }).component(); + + this._labelContainer.addItem( + this._labelComponent + , { + CSSStyles: { + 'align-items': 'center', + 'padding-top': '30px', + 'padding-left': `${this.componentMaxLength}px`, + 'font-size': '16px' + } + }); + this._labelContainer.addItem( + this._descriptionComponent + , { + CSSStyles: { + 'align-items': 'center', + 'padding-top': '10px', + 'padding-left': `${this.componentMaxLength - 50}px`, + 'font-size': '13px' + } + }); + + this.addComponents(formModelBuilder); return this._loader; } public addComponents(formBuilder: azdata.FormBuilder) { - if (this._tableSelectionComponent && this._dataTable) { + this._formBuilder = formBuilder; + if (this._tableSelectionComponent && this._dataTable && this._labelContainer) { this._tableSelectionComponent.addComponents(formBuilder); this._dataTable.addComponents(formBuilder); + if (this._dataTable.isEmpty) { + formBuilder.addFormItem({ title: '', component: this._labelContainer }); + } } } public removeComponents(formBuilder: azdata.FormBuilder) { - if (this._tableSelectionComponent && this._dataTable) { + if (this._tableSelectionComponent && this._dataTable && this._labelContainer) { this._tableSelectionComponent.removeComponents(formBuilder); this._dataTable.removeComponents(formBuilder); + formBuilder.removeFormItem({ title: '', component: this._labelContainer }); } } @@ -91,10 +126,11 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { await this.onLoading(); try { - if (this._tableSelectionComponent) { - this._tableSelectionComponent.refresh(); + if (this._tableSelectionComponent && this._dataTable) { + await this._tableSelectionComponent.refresh(); + await this._dataTable.refresh(); + this.refreshComponents(); } - await this._dataTable?.refresh(); } catch (err) { this.showErrorMessage(constants.getErrorMessage(err)); } finally { @@ -106,6 +142,13 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { return this._dataTable?.data; } + private refreshComponents(): void { + if (this._formBuilder) { + this.removeComponents(this._formBuilder); + this.addComponents(this._formBuilder); + } + } + private async onTableSelected(): Promise { if (this._tableSelectionComponent?.data) { this.importTable = this._tableSelectionComponent?.data; @@ -113,6 +156,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { if (this._dataTable) { await this._dataTable.refresh(); } + this.refreshComponents(); } } diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts index ab98d9ad62..3216f12e76 100644 --- a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts +++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts @@ -25,6 +25,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< private _downloadedFile: ModelArtifact | undefined; private _onModelSelectionChanged: vscode.EventEmitter = new vscode.EventEmitter(); public readonly onModelSelectionChanged: vscode.Event = this._onModelSelectionChanged.event; + public isEmpty: boolean = false; /** * Creates new view @@ -149,7 +150,6 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< } } - /** * Returns the component */ @@ -176,6 +176,8 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< tableData = tableData.concat(models.map(model => this.createTableRow(model))); } + this.isEmpty = models === undefined || models.length === 0; + this._table.data = tableData; } this.onModelSelected(); @@ -275,7 +277,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< if (confirm) { await this.sendDataRequest(DeleteModelEventName, model); if (this.parent) { - await this.parent?.refresh(); + await this.parent.refresh(); } } } catch (error) { diff --git a/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts b/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts index 1f5ef99f13..7a74db20dc 100644 --- a/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts +++ b/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts @@ -108,11 +108,12 @@ export class ImportModelWizard extends ModelViewBase { } else { await this.importAzureModel(this.modelsViewData); } + this._apiWrapper.showInfoMessage(constants.modelRegisteredSuccessfully); await this.storeImportConfigTable(); - this.showInfoMessage(constants.modelRegisteredSuccessfully); + return true; } catch (error) { - this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`); + await this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`); return false; } } diff --git a/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts b/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts index c319fc1f1f..f39f68746a 100644 --- a/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts +++ b/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts @@ -20,6 +20,9 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, private _formBuilder: azdata.FormBuilder | undefined; public tableSelectionComponent: TableSelectionComponent | undefined; private _labelComponent: azdata.TextComponent | undefined; + private _descriptionComponent: azdata.TextComponent | undefined; + private _labelContainer: azdata.FlexContainer | undefined; + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { super(apiWrapper, parent.root, parent); @@ -33,23 +36,40 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, this._formBuilder = modelBuilder.formContainer(); this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, { editable: true, preSelected: true }); + this._descriptionComponent = modelBuilder.text().withProperties({ + width: 200 + }).component(); this._labelComponent = modelBuilder.text().withProperties({ width: 200 }).component(); - const container = modelBuilder.flexContainer().withLayout({ + this._labelContainer = modelBuilder.flexContainer().withLayout({ + flexFlow: 'column', width: 800, - height: '400px', + height: '300px', justifyContent: 'center' - }).withItems([ - this._labelComponent - ], { - CSSStyles: { - 'align-items': 'center', - 'padding-top': '30px', - 'font-size': '16px' - } }).component(); + this._labelContainer.addItem( + this._labelComponent + , { + CSSStyles: { + 'align-items': 'center', + 'padding-top': '10px', + 'padding-left': `${this.componentMaxLength}px`, + 'font-size': '16px' + } + }); + this._labelContainer.addItem( + this._descriptionComponent + , { + CSSStyles: { + 'align-items': 'center', + 'padding-top': '10px', + 'padding-left': `${this.componentMaxLength - 80}px`, + 'font-size': '13px' + } + }); + this.tableSelectionComponent.onSelectedChanged(async () => { await this.onTableSelected(); @@ -59,7 +79,7 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, this._formBuilder.addFormItem({ title: '', - component: container + component: this._labelContainer }); this._form = this._formBuilder.component(); return this._form; @@ -71,15 +91,24 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, } if (this.importTable && this._labelComponent) { - const validated = await this.verifyImportConfigTable(this.importTable); - if (validated) { - this._labelComponent.value = constants.modelSchemaIsAcceptedMessage; + if (!this.validateImportTableName()) { + this._labelComponent.value = constants.selectModelsTableMessage; } else { - this._labelComponent.value = constants.modelSchemaIsNotAcceptedMessage; + const validated = await this.verifyImportConfigTable(this.importTable); + if (validated) { + this._labelComponent.value = constants.modelSchemaIsAcceptedMessage; + } else { + this._labelComponent.value = constants.modelSchemaIsNotAcceptedMessage; + } } } } + private validateImportTableName(): boolean { + return this.importTable?.databaseName !== undefined && this.importTable?.databaseName !== constants.selectDatabaseTitle + && this.importTable?.tableName !== undefined && this.importTable?.tableName !== constants.selectTableTitle; + } + /** * Returns selected data */ @@ -116,7 +145,7 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, public async validate(): Promise { let validated = false; - if (this.data?.databaseName && this.data?.tableName) { + if (this.data && this.validateImportTableName()) { validated = true; validated = await this.verifyImportConfigTable(this.data); if (!validated) { diff --git a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts index 6600f2b032..1e8b7b8690 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts @@ -261,7 +261,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent x.name === modelParameter.name); + let column = values.find(x => x.name.toLocaleUpperCase() === modelParameter.name.toLocaleUpperCase()); if (!column) { column = values.length > 0 ? values[0] : undefined; } diff --git a/extensions/machine-learning/src/views/models/tableSelectionComponent.ts b/extensions/machine-learning/src/views/models/tableSelectionComponent.ts index 6c7b853566..be7a9a0a61 100644 --- a/extensions/machine-learning/src/views/models/tableSelectionComponent.ts +++ b/extensions/machine-learning/src/views/models/tableSelectionComponent.ts @@ -29,7 +29,11 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo private _dbTableComponent: azdata.FlexContainer | undefined; private tableMaxLength = this.componentMaxLength * 2 + 70; private _onSelectedChanged: vscode.EventEmitter = new vscode.EventEmitter(); + private _existingTableButton: azdata.RadioButtonComponent | undefined; + private _newTableButton: azdata.RadioButtonComponent | undefined; + private _newTableName: azdata.InputBoxComponent | undefined; private _existingTablesSelected: boolean = true; + public readonly onSelectedChanged: vscode.Event = this._onSelectedChanged.event; /** @@ -55,50 +59,46 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo await this.onDatabaseSelected(); }); - const existingTableButton = modelBuilder.radioButton().withProperties({ + this._existingTableButton = modelBuilder.radioButton().withProperties({ name: 'tableName', value: 'existing', label: 'Existing table', checked: true }).component(); - const newTableButton = modelBuilder.radioButton().withProperties({ + this._newTableButton = modelBuilder.radioButton().withProperties({ name: 'tableName', value: 'new', label: 'New table', checked: false }).component(); - const newTableName = modelBuilder.inputBox().withProperties({ + this._newTableName = modelBuilder.inputBox().withProperties({ width: this.componentMaxLength - 10, enabled: false }).component(); const group = modelBuilder.groupContainer().withItems([ - existingTableButton, + this._existingTableButton, this._tables, - newTableButton, - newTableName + this._newTableButton, + this._newTableName ], { CSSStyles: { 'padding-top': '5px' } }).component(); - existingTableButton.onDidClick(() => { - if (this._tables) { - this._tables.enabled = existingTableButton.checked; - } - newTableName.enabled = !existingTableButton.checked; - this._existingTablesSelected = existingTableButton.checked || false; + this._existingTableButton.onDidClick(() => { + this._existingTablesSelected = true; + this.refreshTableComponent(); }); - newTableButton.onDidClick(() => { - if (this._tables) { - this._tables.enabled = !newTableButton.checked; - } - newTableName.enabled = newTableButton.checked; - this._existingTablesSelected = existingTableButton.checked || false; + this._newTableButton.onDidClick(() => { + this._existingTablesSelected = false; + this.refreshTableComponent(); }); - newTableName.onTextChanged(async () => { - this._selectedTableName = newTableName.value || ''; - await this.onTableSelected(); + this._newTableName.onTextChanged(async () => { + if (this._newTableName) { + this._selectedTableName = this._newTableName.value || ''; + await this.onTableSelected(); + } }); this._tables.onValueChanged(async (value) => { @@ -192,7 +192,7 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo public async loadData(): Promise { this._dbNames = await this.listDatabaseNames(); let dbNames = this._dbNames; - if (!this._settings.preSelected && !this._dbNames.find(x => x === constants.selectDatabaseTitle)) { + if (!this._dbNames.find(x => x === constants.selectDatabaseTitle)) { dbNames = [constants.selectDatabaseTitle].concat(this._dbNames); } if (this._databases && dbNames && dbNames.length > 0) { @@ -216,35 +216,49 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo } private async onDatabaseSelected(): Promise { - if (this._existingTablesSelected) { - this._tableNames = await this.listTableNames(this.databaseName || ''); - let tableNames = this._tableNames; - - if (this._tableNames && !this._settings.preSelected && !this._tableNames.find(x => x.tableName === constants.selectTableTitle)) { - const firstRow: DatabaseTable = { tableName: constants.selectTableTitle, databaseName: '', schema: '' }; - tableNames = [firstRow].concat(this._tableNames); - } - - if (this._tables && tableNames && tableNames.length > 0) { - this._tables.values = tableNames.map(t => this.getTableFullName(t)); - if (this.importTable) { - const selectedTable = tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema); - if (selectedTable) { - this._selectedTableName = this.getTableFullName(selectedTable); - this._tables.value = this.getTableFullName(selectedTable); - } else { - this._selectedTableName = this._settings.editable ? this.getTableFullName(this.importTable) : this.getTableFullName(tableNames[0]); - } - } else { - this._selectedTableName = this.getTableFullName(tableNames[0]); - } - this._tables.value = this._selectedTableName; - } else if (this._tables) { - this._tables.values = []; - this._tables.value = ''; - } + this._tableNames = await this.listTableNames(this.databaseName || ''); + let tableNames = this._tableNames; + if (this._settings.editable && this._tables && this._existingTableButton && this._newTableButton && this._newTableName) { + this._existingTablesSelected = this._tableNames !== undefined && this._tableNames.length > 0; + this._newTableButton.checked = !this._existingTablesSelected; + this._existingTableButton.checked = this._existingTablesSelected; } + this.refreshTableComponent(); + + + if (this._tableNames && !this._tableNames.find(x => x.tableName === constants.selectTableTitle)) { + const firstRow: DatabaseTable = { tableName: constants.selectTableTitle, databaseName: '', schema: '' }; + tableNames = [firstRow].concat(this._tableNames); + } + + if (this._tables && tableNames && tableNames.length > 0) { + this._tables.values = tableNames.map(t => this.getTableFullName(t)); + if (this.importTable && this.importTable.databaseName === this._databases?.value) { + const selectedTable = tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema); + if (selectedTable) { + this._selectedTableName = this.getTableFullName(selectedTable); + this._tables.value = this.getTableFullName(selectedTable); + } else { + this._selectedTableName = this._settings.editable ? this.getTableFullName(this.importTable) : this.getTableFullName(tableNames[0]); + } + } else { + this._selectedTableName = this.getTableFullName(tableNames[0]); + } + this._tables.value = this._selectedTableName; + } else if (this._tables) { + this._tables.values = []; + this._tables.value = ''; + } + await this.onTableSelected(); + + } + + private refreshTableComponent(): void { + if (this._settings.editable && this._tables && this._existingTableButton && this._newTableButton && this._newTableName) { + this._tables.enabled = this._existingTablesSelected; + this._newTableName.enabled = !this._existingTablesSelected; + } } private getTableFullName(table: DatabaseTable): string { diff --git a/extensions/machine-learning/src/views/widgets/dashboardWidget.ts b/extensions/machine-learning/src/views/widgets/dashboardWidget.ts index 8455425d32..f8732441ab 100644 --- a/extensions/machine-learning/src/views/widgets/dashboardWidget.ts +++ b/extensions/machine-learning/src/views/widgets/dashboardWidget.ts @@ -492,7 +492,9 @@ export class DashboardWidget { 'padding': '10px' } }); - predictionButton.enabled = await this._predictService.serverSupportOnnxModel(); + if (!await this._predictService.serverSupportOnnxModel()) { + console.log(constants.onnxNotSupportedError); + } return tasksContainer; }