diff --git a/extensions/machine-learning/src/common/constants.ts b/extensions/machine-learning/src/common/constants.ts index 64531f9231..9541dfe561 100644 --- a/extensions/machine-learning/src/common/constants.ts +++ b/extensions/machine-learning/src/common/constants.ts @@ -56,7 +56,7 @@ export function cannotFindR(path: string): string { return localize('mls.cannotF export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Verifying package management dependencies"); export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Verifying model management dependencies"); export const noResultError = localize('mls.noResultError', "No Result returned"); -export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed"); +export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required packages are not installed"); export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that."); export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script."); export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action."); @@ -65,6 +65,10 @@ export function confirmInstallPythonPackagesDetails(packages: string): string { return localize('mls.installDependencies.confirmInstallPythonPackages' , "The following Python packages are required to install: {0}", packages); } +export function confirmInstallRPackagesDetails(packages: string): string { + return localize('mls.installDependencies.confirmInstallRPackages' + , "The following R packages are required to install: {0}", packages); +} export function confirmDeleteModel(modelName: string): string { return localize('models.confirmDeleteModel' , "Are you sure you want to delete model '{0}?", modelName); @@ -207,6 +211,7 @@ export const invalidModelParametersError = localize('models.invalidModelParamete export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model"); export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table"); export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMismatchWarning', "The data type of the source table column does not match the required input field’s type."); +export const outputColumnDataTypeNotSupportedWarning = localize('models.outputColumnDataTypeNotSupportedWarning', "The data type of output column does not match the output field’s type."); export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!"); diff --git a/extensions/machine-learning/src/packageManagement/packageManager.ts b/extensions/machine-learning/src/packageManagement/packageManager.ts index 60df14debb..ab1d9f3f7b 100644 --- a/extensions/machine-learning/src/packageManagement/packageManager.ts +++ b/extensions/machine-learning/src/packageManagement/packageManager.ts @@ -71,12 +71,10 @@ export class PackageManager { // Only execute the command if there's a valid connection with ml configuration enabled // let connection = await this.getCurrentConnection(); - let isPythonInstalled = await this._service.isPythonInstalled(connection); - let isRInstalled = await this._service.isRInstalled(connection); let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined; - if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) { + if (connection && await this._sqlPythonPackagePackageManager.canUseProvider()) { defaultProvider = this._sqlPythonPackagePackageManager; - } else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) { + } else if (connection && await this._sqlRPackageManager.canUseProvider()) { defaultProvider = this._sqlRPackageManager; } if (connection && defaultProvider) { @@ -119,7 +117,7 @@ export class PackageManager { public async installDependencies(): Promise { await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [ this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages), - this.installRequiredRPackages()], true); + this.installRequiredRPackages()], false); await this.verifyOdbcInstalled(); } @@ -135,12 +133,30 @@ export class PackageManager { await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder)); const packages = this._config.requiredSqlRPackages.filter(p => !p.platform || p.platform === process.platform); + let packagesToInstall: PackageConfigModel[] = []; - // Installs packages in order of listed in the config. The order specifies the dependency of the packages and - // packages cannot install as parallel because of the dependency for each other for (let index = 0; index < packages.length; index++) { - const packageName = packages[index]; - await this.installRPackage(packageName); + const packageDetail = packages[index]; + const isInstalled = await this.verifyRPackageInstalled(packageDetail.name); + if (!isInstalled) { + packagesToInstall.push(packageDetail); + } + } + + if (packagesToInstall.length > 0) { + this._apiWrapper.showInfoMessage(constants.confirmInstallRPackagesDetails(packagesToInstall.map(x => x.name).join(', '))); + let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages, this._apiWrapper); + if (confirmed) { + this._outputChannel.appendLine(constants.installDependenciesPackages); + // Installs packages in order of listed in the config. The order specifies the dependency of the packages and + // packages cannot install as parallel because of the dependency for each other + for (let index = 0; index < packagesToInstall.length; index++) { + const packageName = packagesToInstall[index]; + await this.installRPackage(packageName); + } + } else { + throw Error(constants.requiredPackagesNotInstalled); + } } } @@ -254,6 +270,23 @@ export class PackageManager { return await this._processService.executeBufferedCommand(cmd, this._outputChannel); } + private async verifyRPackageInstalled(packageName: string): Promise { + let rExecutable = await this.getRExecutable(); + + let scripts: string[] = [ + 'formals(quit)$save <- formals(q)$save <- "no"', + `library(${packageName})`, + 'q()' + ]; + + try { + await this._processService.execScripts(rExecutable, scripts, ['--vanilla'], undefined); + return true; + } catch { + return false; + } + } + private async installRPackage(model: PackageConfigModel): Promise { let output = ''; let cmd = ''; diff --git a/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts b/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts index 7e7edbc02f..40e0933a3e 100644 --- a/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts +++ b/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts @@ -37,7 +37,7 @@ describe('Package Manager', () => { testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();}); testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);}); testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(true);}); - testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);}); + //testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);}); testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);}); let packageManager = createPackageManager(testContext); await packageManager.managePackages(); @@ -52,7 +52,7 @@ describe('Package Manager', () => { testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny())); testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);}); testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(false);}); - testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);}); + //testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);}); testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);}); let packageManager = createPackageManager(testContext); await packageManager.managePackages(); @@ -72,18 +72,16 @@ describe('Package Manager', () => { testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); }); - it('installDependencies Should download sqlmlutils if does not exist', async function (): Promise { + it('installDependencies Should download R sqlmlutils if does not exist', async function (): Promise { let testContext = createContext(); - - let installedPackages = `[ - {"name":"pymssql","version":"2.1.4"}, - {"name":"sqlmlutils","version":"1.1.1"} - ]`; testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { operationInfo.operation(testContext.op); }); - testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);}); + testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ + label: 'Yes' + })); + testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), undefined)).returns(() => {return Promise.reject('');}); let packageManager = createPackageManager(testContext); await packageManager.installDependencies(); diff --git a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts index 106934304f..79f8029807 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts @@ -78,7 +78,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID const data = this.data; const validated = data !== undefined && data.databaseName !== undefined && data.inputColumns !== undefined && data.outputColumns !== undefined && data.tableName !== undefined && data.databaseName !== constants.selectDatabaseTitle && data.tableName !== constants.selectTableTitle - && !data.inputColumns.find(x => x.columnName === constants.selectColumnTitle); + && !data.inputColumns.find(x => (x.columnName === constants.selectColumnTitle) || !x.columnName); if (!validated) { this.showErrorMessage(constants.invalidModelParametersError); } diff --git a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts index 88354239cd..d843147431 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts @@ -209,20 +209,44 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent { + }); + const css = { + 'padding-top': '5px', + 'padding-right': '5px', + 'margin': '0px' + }; const name = modelParameter.name; - const dataType = dataTypes.find(x => x === modelParameter.type); - if (dataType) { - nameInput.value = dataType; - } else { + let dataType = dataTypes.find(x => x === modelParameter.type); + if (!dataType) { // Output type not supported // - modelParameter.type = dataTypes[0]; + dataType = dataTypes[0]; + outputContainer.addItem(warningButton, { + CSSStyles: css + }); } - this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type }); + let nameInput = this._modelBuilder.dropDown().withProperties({ + values: dataTypes, + width: this.componentMaxLength, + value: dataType + }).component(); + outputContainer.addItem(nameInput, { + CSSStyles: { + 'padding': '0px', + 'padding-right': '5px', + 'margin': '0px' + } + }); + + + this._parameters.push({ columnName: name, paramName: name, dataType: dataType }); nameInput.onValueChanged(() => { const value = nameInput.value; @@ -231,8 +255,14 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent { }); @@ -296,7 +326,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent x.columnName === value); - if (currentColumn && modelParameter.type === currentColumn?.dataType) { + if (currentColumn && modelParameter.type !== currentColumn?.dataType) { inputContainer.removeItem(warningButton); } else { inputContainer.addItem(warningButton, { @@ -341,11 +371,11 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent { const locations = await dialog.model.getLocations(); + let location: string; let component: azdata.Component; if (locations && locations.length === 1) { component = view.modelBuilder.text().withProperties({ value: locations[0].displayName }).component(); + location = locations[0].name; } else if (locations && locations.length > 1) { let dropdownValues = locations.map(x => { return { @@ -179,6 +181,7 @@ export class InstalledPackagesTab { }); const currentLocation = await dialog.model.getCurrentLocation(); const selectedLocation = dropdownValues.find(x => x.name === currentLocation); + location = currentLocation || locations[0].name; let locationDropDown = view.modelBuilder.dropDown().withProperties({ values: dropdownValues, value: selectedLocation || dropdownValues[0] @@ -198,8 +201,8 @@ export class InstalledPackagesTab { component = view.modelBuilder.text().withProperties({ }).component(); } - if (locations && locations.length > 0) { - dialog.changeLocation(locations[0].name); + if (location) { + dialog.changeLocation(location); } return component; }