ML - Bug fixing (#11920)

* Fixed a bug with validating inputs when generating predict script

* Fixed the bug with verifying R packages

* Fixed the tests

* Added warning for when output column data type doesn't match with model output data type

* Fix the issue with selecting db
This commit is contained in:
Leila Lali
2020-08-31 20:02:30 -07:00
committed by GitHub
parent 23bd05ea68
commit 635da9a2b2
6 changed files with 105 additions and 36 deletions

View File

@@ -56,7 +56,7 @@ export function cannotFindR(path: string): string { return localize('mls.cannotF
export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Verifying package management dependencies");
export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Verifying model management dependencies");
export const noResultError = localize('mls.noResultError', "No Result returned");
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required packages are not installed");
export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that.");
export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script.");
export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action.");
@@ -65,6 +65,10 @@ export function confirmInstallPythonPackagesDetails(packages: string): string {
return localize('mls.installDependencies.confirmInstallPythonPackages'
, "The following Python packages are required to install: {0}", packages);
}
export function confirmInstallRPackagesDetails(packages: string): string {
return localize('mls.installDependencies.confirmInstallRPackages'
, "The following R packages are required to install: {0}", packages);
}
export function confirmDeleteModel(modelName: string): string {
return localize('models.confirmDeleteModel'
, "Are you sure you want to delete model '{0}?", modelName);
@@ -207,6 +211,7 @@ export const invalidModelParametersError = localize('models.invalidModelParamete
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table");
export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMismatchWarning', "The data type of the source table column does not match the required input fields type.");
export const outputColumnDataTypeNotSupportedWarning = localize('models.outputColumnDataTypeNotSupportedWarning', "The data type of output column does not match the output fields type.");
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!");

View File

@@ -71,12 +71,10 @@ export class PackageManager {
// Only execute the command if there's a valid connection with ml configuration enabled
//
let connection = await this.getCurrentConnection();
let isPythonInstalled = await this._service.isPythonInstalled(connection);
let isRInstalled = await this._service.isRInstalled(connection);
let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined;
if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) {
if (connection && await this._sqlPythonPackagePackageManager.canUseProvider()) {
defaultProvider = this._sqlPythonPackagePackageManager;
} else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) {
} else if (connection && await this._sqlRPackageManager.canUseProvider()) {
defaultProvider = this._sqlRPackageManager;
}
if (connection && defaultProvider) {
@@ -119,7 +117,7 @@ export class PackageManager {
public async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
this.installRequiredRPackages()], true);
this.installRequiredRPackages()], false);
await this.verifyOdbcInstalled();
}
@@ -135,12 +133,30 @@ export class PackageManager {
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
const packages = this._config.requiredSqlRPackages.filter(p => !p.platform || p.platform === process.platform);
let packagesToInstall: PackageConfigModel[] = [];
// Installs packages in order of listed in the config. The order specifies the dependency of the packages and
// packages cannot install as parallel because of the dependency for each other
for (let index = 0; index < packages.length; index++) {
const packageName = packages[index];
await this.installRPackage(packageName);
const packageDetail = packages[index];
const isInstalled = await this.verifyRPackageInstalled(packageDetail.name);
if (!isInstalled) {
packagesToInstall.push(packageDetail);
}
}
if (packagesToInstall.length > 0) {
this._apiWrapper.showInfoMessage(constants.confirmInstallRPackagesDetails(packagesToInstall.map(x => x.name).join(', ')));
let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages, this._apiWrapper);
if (confirmed) {
this._outputChannel.appendLine(constants.installDependenciesPackages);
// Installs packages in order of listed in the config. The order specifies the dependency of the packages and
// packages cannot install as parallel because of the dependency for each other
for (let index = 0; index < packagesToInstall.length; index++) {
const packageName = packagesToInstall[index];
await this.installRPackage(packageName);
}
} else {
throw Error(constants.requiredPackagesNotInstalled);
}
}
}
@@ -254,6 +270,23 @@ export class PackageManager {
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
}
private async verifyRPackageInstalled(packageName: string): Promise<boolean> {
let rExecutable = await this.getRExecutable();
let scripts: string[] = [
'formals(quit)$save <- formals(q)$save <- "no"',
`library(${packageName})`,
'q()'
];
try {
await this._processService.execScripts(rExecutable, scripts, ['--vanilla'], undefined);
return true;
} catch {
return false;
}
}
private async installRPackage(model: PackageConfigModel): Promise<string> {
let output = '';
let cmd = '';

View File

@@ -37,7 +37,7 @@ describe('Package Manager', () => {
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
//testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
@@ -52,7 +52,7 @@ describe('Package Manager', () => {
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()));
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
//testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
@@ -72,18 +72,16 @@ describe('Package Manager', () => {
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('installDependencies Should download sqlmlutils if does not exist', async function (): Promise<void> {
it('installDependencies Should download R sqlmlutils if does not exist', async function (): Promise<void> {
let testContext = createContext();
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"},
{"name":"sqlmlutils","version":"1.1.1"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);});
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), undefined)).returns(() => {return Promise.reject('');});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();

View File

@@ -78,7 +78,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
const data = this.data;
const validated = data !== undefined && data.databaseName !== undefined && data.inputColumns !== undefined && data.outputColumns !== undefined
&& data.tableName !== undefined && data.databaseName !== constants.selectDatabaseTitle && data.tableName !== constants.selectTableTitle
&& !data.inputColumns.find(x => x.columnName === constants.selectColumnTitle);
&& !data.inputColumns.find(x => (x.columnName === constants.selectColumnTitle) || !x.columnName);
if (!validated) {
this.showErrorMessage(constants.invalidModelParametersError);
}

View File

@@ -209,20 +209,44 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
if (this._modelBuilder) {
let nameInput = this._modelBuilder.dropDown().withProperties({
values: dataTypes,
width: this.componentMaxLength
const outputContainer = this._modelBuilder.flexContainer().withLayout({
flexFlow: 'row',
width: this.componentMaxLength + 20,
justifyContent: 'flex-start'
}).component();
const warningButton = this.createWarningButton(constants.outputColumnDataTypeNotSupportedWarning);
warningButton.onDidClick(() => {
});
const css = {
'padding-top': '5px',
'padding-right': '5px',
'margin': '0px'
};
const name = modelParameter.name;
const dataType = dataTypes.find(x => x === modelParameter.type);
if (dataType) {
nameInput.value = dataType;
} else {
let dataType = dataTypes.find(x => x === modelParameter.type);
if (!dataType) {
// Output type not supported
//
modelParameter.type = dataTypes[0];
dataType = dataTypes[0];
outputContainer.addItem(warningButton, {
CSSStyles: css
});
}
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
let nameInput = this._modelBuilder.dropDown().withProperties({
values: dataTypes,
width: this.componentMaxLength,
value: dataType
}).component();
outputContainer.addItem(nameInput, {
CSSStyles: {
'padding': '0px',
'padding-right': '5px',
'margin': '0px'
}
});
this._parameters.push({ columnName: name, paramName: name, dataType: dataType });
nameInput.onValueChanged(() => {
const value = <string>nameInput.value;
@@ -231,8 +255,14 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
if (selectedRow) {
selectedRow.dataType = value;
}
outputContainer.addItem(warningButton, {
CSSStyles: css
});
} else {
outputContainer.removeItem(warningButton);
}
});
let displayNameInput = this._modelBuilder.inputBox().withProperties({
value: name,
width: 200
@@ -243,7 +273,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
selectedRow.columnName = displayNameInput.value || name;
}
});
return [`${name}(${modelParameter.originalType ? modelParameter.originalType : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
return [`${name}(${modelParameter.originalType ? modelParameter.originalType : constants.unsupportedModelParameterType})`, displayNameInput, outputContainer];
}
return [];
@@ -276,7 +306,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
width: this.componentMaxLength + 20,
justifyContent: 'flex-start'
}).component();
const warningButton = this.createWarningButton();
const warningButton = this.createWarningButton(constants.columnDataTypeMismatchWarning);
warningButton.onDidClick(() => {
});
@@ -296,7 +326,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
}
const currentColumn = columns.find(x => x.columnName === value);
if (currentColumn && modelParameter.type === currentColumn?.dataType) {
if (currentColumn && modelParameter.type !== currentColumn?.dataType) {
inputContainer.removeItem(warningButton);
} else {
inputContainer.addItem(warningButton, {
@@ -341,11 +371,11 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
return [];
}
private createWarningButton(): azdata.ButtonComponent {
private createWarningButton(message: string): azdata.ButtonComponent {
const warningButton = this._modelBuilder.button().withProperties({
width: '10px',
height: '10px',
title: constants.columnDataTypeMismatchWarning,
title: message,
iconPath: {
dark: this.asAbsolutePath('images/dark/warning_notification_inverse.svg'),
light: this.asAbsolutePath('images/light/warning_notification.svg'),

View File

@@ -165,11 +165,13 @@ export class InstalledPackagesTab {
*/
public static async getLocationComponent(view: azdata.ModelView, dialog: ManagePackagesDialog): Promise<azdata.Component> {
const locations = await dialog.model.getLocations();
let location: string;
let component: azdata.Component;
if (locations && locations.length === 1) {
component = view.modelBuilder.text().withProperties({
value: locations[0].displayName
}).component();
location = locations[0].name;
} else if (locations && locations.length > 1) {
let dropdownValues = locations.map(x => {
return {
@@ -179,6 +181,7 @@ export class InstalledPackagesTab {
});
const currentLocation = await dialog.model.getCurrentLocation();
const selectedLocation = dropdownValues.find(x => x.name === currentLocation);
location = currentLocation || locations[0].name;
let locationDropDown = view.modelBuilder.dropDown().withProperties({
values: dropdownValues,
value: selectedLocation || dropdownValues[0]
@@ -198,8 +201,8 @@ export class InstalledPackagesTab {
component = view.modelBuilder.text().withProperties({
}).component();
}
if (locations && locations.length > 0) {
dialog.changeLocation(locations[0].name);
if (location) {
dialog.changeLocation(location);
}
return component;
}