ML - Bug fixing (#11920)

* Fixed a bug with validating inputs when generating predict script

* Fixed the bug with verifying R packages

* Fixed the tests

* Added warning for when output column data type doesn't match with model output data type

* Fix the issue with selecting db
This commit is contained in:
Leila Lali
2020-08-31 20:02:30 -07:00
committed by GitHub
parent 23bd05ea68
commit 635da9a2b2
6 changed files with 105 additions and 36 deletions

View File

@@ -56,7 +56,7 @@ export function cannotFindR(path: string): string { return localize('mls.cannotF
export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Verifying package management dependencies"); export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Verifying package management dependencies");
export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Verifying model management dependencies"); export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Verifying model management dependencies");
export const noResultError = localize('mls.noResultError', "No Result returned"); export const noResultError = localize('mls.noResultError', "No Result returned");
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed"); export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required packages are not installed");
export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that."); export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that.");
export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script."); export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script.");
export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action."); export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action.");
@@ -65,6 +65,10 @@ export function confirmInstallPythonPackagesDetails(packages: string): string {
return localize('mls.installDependencies.confirmInstallPythonPackages' return localize('mls.installDependencies.confirmInstallPythonPackages'
, "The following Python packages are required to install: {0}", packages); , "The following Python packages are required to install: {0}", packages);
} }
export function confirmInstallRPackagesDetails(packages: string): string {
return localize('mls.installDependencies.confirmInstallRPackages'
, "The following R packages are required to install: {0}", packages);
}
export function confirmDeleteModel(modelName: string): string { export function confirmDeleteModel(modelName: string): string {
return localize('models.confirmDeleteModel' return localize('models.confirmDeleteModel'
, "Are you sure you want to delete model '{0}?", modelName); , "Are you sure you want to delete model '{0}?", modelName);
@@ -207,6 +211,7 @@ export const invalidModelParametersError = localize('models.invalidModelParamete
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model"); export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table"); export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table");
export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMismatchWarning', "The data type of the source table column does not match the required input fields type."); export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMismatchWarning', "The data type of the source table column does not match the required input fields type.");
export const outputColumnDataTypeNotSupportedWarning = localize('models.outputColumnDataTypeNotSupportedWarning', "The data type of output column does not match the output fields type.");
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!"); export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!");

View File

@@ -71,12 +71,10 @@ export class PackageManager {
// Only execute the command if there's a valid connection with ml configuration enabled // Only execute the command if there's a valid connection with ml configuration enabled
// //
let connection = await this.getCurrentConnection(); let connection = await this.getCurrentConnection();
let isPythonInstalled = await this._service.isPythonInstalled(connection);
let isRInstalled = await this._service.isRInstalled(connection);
let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined; let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined;
if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) { if (connection && await this._sqlPythonPackagePackageManager.canUseProvider()) {
defaultProvider = this._sqlPythonPackagePackageManager; defaultProvider = this._sqlPythonPackagePackageManager;
} else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) { } else if (connection && await this._sqlRPackageManager.canUseProvider()) {
defaultProvider = this._sqlRPackageManager; defaultProvider = this._sqlRPackageManager;
} }
if (connection && defaultProvider) { if (connection && defaultProvider) {
@@ -119,7 +117,7 @@ export class PackageManager {
public async installDependencies(): Promise<void> { public async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [ await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages), this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
this.installRequiredRPackages()], true); this.installRequiredRPackages()], false);
await this.verifyOdbcInstalled(); await this.verifyOdbcInstalled();
} }
@@ -135,13 +133,31 @@ export class PackageManager {
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder)); await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
const packages = this._config.requiredSqlRPackages.filter(p => !p.platform || p.platform === process.platform); const packages = this._config.requiredSqlRPackages.filter(p => !p.platform || p.platform === process.platform);
let packagesToInstall: PackageConfigModel[] = [];
for (let index = 0; index < packages.length; index++) {
const packageDetail = packages[index];
const isInstalled = await this.verifyRPackageInstalled(packageDetail.name);
if (!isInstalled) {
packagesToInstall.push(packageDetail);
}
}
if (packagesToInstall.length > 0) {
this._apiWrapper.showInfoMessage(constants.confirmInstallRPackagesDetails(packagesToInstall.map(x => x.name).join(', ')));
let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages, this._apiWrapper);
if (confirmed) {
this._outputChannel.appendLine(constants.installDependenciesPackages);
// Installs packages in order of listed in the config. The order specifies the dependency of the packages and // Installs packages in order of listed in the config. The order specifies the dependency of the packages and
// packages cannot install as parallel because of the dependency for each other // packages cannot install as parallel because of the dependency for each other
for (let index = 0; index < packages.length; index++) { for (let index = 0; index < packagesToInstall.length; index++) {
const packageName = packages[index]; const packageName = packagesToInstall[index];
await this.installRPackage(packageName); await this.installRPackage(packageName);
} }
} else {
throw Error(constants.requiredPackagesNotInstalled);
}
}
} }
/** /**
@@ -254,6 +270,23 @@ export class PackageManager {
return await this._processService.executeBufferedCommand(cmd, this._outputChannel); return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
} }
private async verifyRPackageInstalled(packageName: string): Promise<boolean> {
let rExecutable = await this.getRExecutable();
let scripts: string[] = [
'formals(quit)$save <- formals(q)$save <- "no"',
`library(${packageName})`,
'q()'
];
try {
await this._processService.execScripts(rExecutable, scripts, ['--vanilla'], undefined);
return true;
} catch {
return false;
}
}
private async installRPackage(model: PackageConfigModel): Promise<string> { private async installRPackage(model: PackageConfigModel): Promise<string> {
let output = ''; let output = '';
let cmd = ''; let cmd = '';

View File

@@ -37,7 +37,7 @@ describe('Package Manager', () => {
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();}); testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);}); testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(true);}); testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);}); //testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);}); testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext); let packageManager = createPackageManager(testContext);
await packageManager.managePackages(); await packageManager.managePackages();
@@ -52,7 +52,7 @@ describe('Package Manager', () => {
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny())); testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()));
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);}); testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(false);}); testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);}); //testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);}); testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext); let packageManager = createPackageManager(testContext);
await packageManager.managePackages(); await packageManager.managePackages();
@@ -72,18 +72,16 @@ describe('Package Manager', () => {
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once()); testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
}); });
it('installDependencies Should download sqlmlutils if does not exist', async function (): Promise<void> { it('installDependencies Should download R sqlmlutils if does not exist', async function (): Promise<void> {
let testContext = createContext(); let testContext = createContext();
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"},
{"name":"sqlmlutils","version":"1.1.1"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op); operationInfo.operation(testContext.op);
}); });
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);}); testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), undefined)).returns(() => {return Promise.reject('');});
let packageManager = createPackageManager(testContext); let packageManager = createPackageManager(testContext);
await packageManager.installDependencies(); await packageManager.installDependencies();

View File

@@ -78,7 +78,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
const data = this.data; const data = this.data;
const validated = data !== undefined && data.databaseName !== undefined && data.inputColumns !== undefined && data.outputColumns !== undefined const validated = data !== undefined && data.databaseName !== undefined && data.inputColumns !== undefined && data.outputColumns !== undefined
&& data.tableName !== undefined && data.databaseName !== constants.selectDatabaseTitle && data.tableName !== constants.selectTableTitle && data.tableName !== undefined && data.databaseName !== constants.selectDatabaseTitle && data.tableName !== constants.selectTableTitle
&& !data.inputColumns.find(x => x.columnName === constants.selectColumnTitle); && !data.inputColumns.find(x => (x.columnName === constants.selectColumnTitle) || !x.columnName);
if (!validated) { if (!validated) {
this.showErrorMessage(constants.invalidModelParametersError); this.showErrorMessage(constants.invalidModelParametersError);
} }

View File

@@ -209,20 +209,44 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] { private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
if (this._modelBuilder) { if (this._modelBuilder) {
let nameInput = this._modelBuilder.dropDown().withProperties({ const outputContainer = this._modelBuilder.flexContainer().withLayout({
values: dataTypes, flexFlow: 'row',
width: this.componentMaxLength width: this.componentMaxLength + 20,
justifyContent: 'flex-start'
}).component(); }).component();
const warningButton = this.createWarningButton(constants.outputColumnDataTypeNotSupportedWarning);
warningButton.onDidClick(() => {
});
const css = {
'padding-top': '5px',
'padding-right': '5px',
'margin': '0px'
};
const name = modelParameter.name; const name = modelParameter.name;
const dataType = dataTypes.find(x => x === modelParameter.type); let dataType = dataTypes.find(x => x === modelParameter.type);
if (dataType) { if (!dataType) {
nameInput.value = dataType;
} else {
// Output type not supported // Output type not supported
// //
modelParameter.type = dataTypes[0]; dataType = dataTypes[0];
outputContainer.addItem(warningButton, {
CSSStyles: css
});
} }
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type }); let nameInput = this._modelBuilder.dropDown().withProperties({
values: dataTypes,
width: this.componentMaxLength,
value: dataType
}).component();
outputContainer.addItem(nameInput, {
CSSStyles: {
'padding': '0px',
'padding-right': '5px',
'margin': '0px'
}
});
this._parameters.push({ columnName: name, paramName: name, dataType: dataType });
nameInput.onValueChanged(() => { nameInput.onValueChanged(() => {
const value = <string>nameInput.value; const value = <string>nameInput.value;
@@ -231,8 +255,14 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
if (selectedRow) { if (selectedRow) {
selectedRow.dataType = value; selectedRow.dataType = value;
} }
outputContainer.addItem(warningButton, {
CSSStyles: css
});
} else {
outputContainer.removeItem(warningButton);
} }
}); });
let displayNameInput = this._modelBuilder.inputBox().withProperties({ let displayNameInput = this._modelBuilder.inputBox().withProperties({
value: name, value: name,
width: 200 width: 200
@@ -243,7 +273,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
selectedRow.columnName = displayNameInput.value || name; selectedRow.columnName = displayNameInput.value || name;
} }
}); });
return [`${name}(${modelParameter.originalType ? modelParameter.originalType : constants.unsupportedModelParameterType})`, displayNameInput, nameInput]; return [`${name}(${modelParameter.originalType ? modelParameter.originalType : constants.unsupportedModelParameterType})`, displayNameInput, outputContainer];
} }
return []; return [];
@@ -276,7 +306,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
width: this.componentMaxLength + 20, width: this.componentMaxLength + 20,
justifyContent: 'flex-start' justifyContent: 'flex-start'
}).component(); }).component();
const warningButton = this.createWarningButton(); const warningButton = this.createWarningButton(constants.columnDataTypeMismatchWarning);
warningButton.onDidClick(() => { warningButton.onDidClick(() => {
}); });
@@ -296,7 +326,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
} }
const currentColumn = columns.find(x => x.columnName === value); const currentColumn = columns.find(x => x.columnName === value);
if (currentColumn && modelParameter.type === currentColumn?.dataType) { if (currentColumn && modelParameter.type !== currentColumn?.dataType) {
inputContainer.removeItem(warningButton); inputContainer.removeItem(warningButton);
} else { } else {
inputContainer.addItem(warningButton, { inputContainer.addItem(warningButton, {
@@ -341,11 +371,11 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
return []; return [];
} }
private createWarningButton(): azdata.ButtonComponent { private createWarningButton(message: string): azdata.ButtonComponent {
const warningButton = this._modelBuilder.button().withProperties({ const warningButton = this._modelBuilder.button().withProperties({
width: '10px', width: '10px',
height: '10px', height: '10px',
title: constants.columnDataTypeMismatchWarning, title: message,
iconPath: { iconPath: {
dark: this.asAbsolutePath('images/dark/warning_notification_inverse.svg'), dark: this.asAbsolutePath('images/dark/warning_notification_inverse.svg'),
light: this.asAbsolutePath('images/light/warning_notification.svg'), light: this.asAbsolutePath('images/light/warning_notification.svg'),

View File

@@ -165,11 +165,13 @@ export class InstalledPackagesTab {
*/ */
public static async getLocationComponent(view: azdata.ModelView, dialog: ManagePackagesDialog): Promise<azdata.Component> { public static async getLocationComponent(view: azdata.ModelView, dialog: ManagePackagesDialog): Promise<azdata.Component> {
const locations = await dialog.model.getLocations(); const locations = await dialog.model.getLocations();
let location: string;
let component: azdata.Component; let component: azdata.Component;
if (locations && locations.length === 1) { if (locations && locations.length === 1) {
component = view.modelBuilder.text().withProperties({ component = view.modelBuilder.text().withProperties({
value: locations[0].displayName value: locations[0].displayName
}).component(); }).component();
location = locations[0].name;
} else if (locations && locations.length > 1) { } else if (locations && locations.length > 1) {
let dropdownValues = locations.map(x => { let dropdownValues = locations.map(x => {
return { return {
@@ -179,6 +181,7 @@ export class InstalledPackagesTab {
}); });
const currentLocation = await dialog.model.getCurrentLocation(); const currentLocation = await dialog.model.getCurrentLocation();
const selectedLocation = dropdownValues.find(x => x.name === currentLocation); const selectedLocation = dropdownValues.find(x => x.name === currentLocation);
location = currentLocation || locations[0].name;
let locationDropDown = view.modelBuilder.dropDown().withProperties({ let locationDropDown = view.modelBuilder.dropDown().withProperties({
values: dropdownValues, values: dropdownValues,
value: selectedLocation || dropdownValues[0] value: selectedLocation || dropdownValues[0]
@@ -198,8 +201,8 @@ export class InstalledPackagesTab {
component = view.modelBuilder.text().withProperties({ component = view.modelBuilder.text().withProperties({
}).component(); }).component();
} }
if (locations && locations.length > 0) { if (location) {
dialog.changeLocation(locations[0].name); dialog.changeLocation(location);
} }
return component; return component;
} }