mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-05 09:35:39 -05:00
ML - Bug fixing (#11920)
* Fixed a bug with validating inputs when generating predict script * Fixed the bug with verifying R packages * Fixed the tests * Added warning for when output column data type doesn't match with model output data type * Fix the issue with selecting db
This commit is contained in:
@@ -56,7 +56,7 @@ export function cannotFindR(path: string): string { return localize('mls.cannotF
|
||||
export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Verifying package management dependencies");
|
||||
export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Verifying model management dependencies");
|
||||
export const noResultError = localize('mls.noResultError', "No Result returned");
|
||||
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
|
||||
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required packages are not installed");
|
||||
export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that.");
|
||||
export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script.");
|
||||
export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action.");
|
||||
@@ -65,6 +65,10 @@ export function confirmInstallPythonPackagesDetails(packages: string): string {
|
||||
return localize('mls.installDependencies.confirmInstallPythonPackages'
|
||||
, "The following Python packages are required to install: {0}", packages);
|
||||
}
|
||||
export function confirmInstallRPackagesDetails(packages: string): string {
|
||||
return localize('mls.installDependencies.confirmInstallRPackages'
|
||||
, "The following R packages are required to install: {0}", packages);
|
||||
}
|
||||
export function confirmDeleteModel(modelName: string): string {
|
||||
return localize('models.confirmDeleteModel'
|
||||
, "Are you sure you want to delete model '{0}?", modelName);
|
||||
@@ -207,6 +211,7 @@ export const invalidModelParametersError = localize('models.invalidModelParamete
|
||||
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
||||
export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table");
|
||||
export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMismatchWarning', "The data type of the source table column does not match the required input field’s type.");
|
||||
export const outputColumnDataTypeNotSupportedWarning = localize('models.outputColumnDataTypeNotSupportedWarning', "The data type of output column does not match the output field’s type.");
|
||||
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||
export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!");
|
||||
|
||||
@@ -71,12 +71,10 @@ export class PackageManager {
|
||||
// Only execute the command if there's a valid connection with ml configuration enabled
|
||||
//
|
||||
let connection = await this.getCurrentConnection();
|
||||
let isPythonInstalled = await this._service.isPythonInstalled(connection);
|
||||
let isRInstalled = await this._service.isRInstalled(connection);
|
||||
let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined;
|
||||
if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) {
|
||||
if (connection && await this._sqlPythonPackagePackageManager.canUseProvider()) {
|
||||
defaultProvider = this._sqlPythonPackagePackageManager;
|
||||
} else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) {
|
||||
} else if (connection && await this._sqlRPackageManager.canUseProvider()) {
|
||||
defaultProvider = this._sqlRPackageManager;
|
||||
}
|
||||
if (connection && defaultProvider) {
|
||||
@@ -119,7 +117,7 @@ export class PackageManager {
|
||||
public async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [
|
||||
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
|
||||
this.installRequiredRPackages()], true);
|
||||
this.installRequiredRPackages()], false);
|
||||
|
||||
await this.verifyOdbcInstalled();
|
||||
}
|
||||
@@ -135,12 +133,30 @@ export class PackageManager {
|
||||
|
||||
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
|
||||
const packages = this._config.requiredSqlRPackages.filter(p => !p.platform || p.platform === process.platform);
|
||||
let packagesToInstall: PackageConfigModel[] = [];
|
||||
|
||||
// Installs packages in order of listed in the config. The order specifies the dependency of the packages and
|
||||
// packages cannot install as parallel because of the dependency for each other
|
||||
for (let index = 0; index < packages.length; index++) {
|
||||
const packageName = packages[index];
|
||||
await this.installRPackage(packageName);
|
||||
const packageDetail = packages[index];
|
||||
const isInstalled = await this.verifyRPackageInstalled(packageDetail.name);
|
||||
if (!isInstalled) {
|
||||
packagesToInstall.push(packageDetail);
|
||||
}
|
||||
}
|
||||
|
||||
if (packagesToInstall.length > 0) {
|
||||
this._apiWrapper.showInfoMessage(constants.confirmInstallRPackagesDetails(packagesToInstall.map(x => x.name).join(', ')));
|
||||
let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages, this._apiWrapper);
|
||||
if (confirmed) {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackages);
|
||||
// Installs packages in order of listed in the config. The order specifies the dependency of the packages and
|
||||
// packages cannot install as parallel because of the dependency for each other
|
||||
for (let index = 0; index < packagesToInstall.length; index++) {
|
||||
const packageName = packagesToInstall[index];
|
||||
await this.installRPackage(packageName);
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.requiredPackagesNotInstalled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,6 +270,23 @@ export class PackageManager {
|
||||
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
}
|
||||
|
||||
private async verifyRPackageInstalled(packageName: string): Promise<boolean> {
|
||||
let rExecutable = await this.getRExecutable();
|
||||
|
||||
let scripts: string[] = [
|
||||
'formals(quit)$save <- formals(q)$save <- "no"',
|
||||
`library(${packageName})`,
|
||||
'q()'
|
||||
];
|
||||
|
||||
try {
|
||||
await this._processService.execScripts(rExecutable, scripts, ['--vanilla'], undefined);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private async installRPackage(model: PackageConfigModel): Promise<string> {
|
||||
let output = '';
|
||||
let cmd = '';
|
||||
|
||||
@@ -37,7 +37,7 @@ describe('Package Manager', () => {
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
|
||||
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
//testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.managePackages();
|
||||
@@ -52,7 +52,7 @@ describe('Package Manager', () => {
|
||||
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()));
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
|
||||
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(false);});
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
//testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.managePackages();
|
||||
@@ -72,18 +72,16 @@ describe('Package Manager', () => {
|
||||
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
|
||||
it('installDependencies Should download sqlmlutils if does not exist', async function (): Promise<void> {
|
||||
it('installDependencies Should download R sqlmlutils if does not exist', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"},
|
||||
{"name":"sqlmlutils","version":"1.1.1"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);});
|
||||
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), undefined)).returns(() => {return Promise.reject('');});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
|
||||
@@ -78,7 +78,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
|
||||
const data = this.data;
|
||||
const validated = data !== undefined && data.databaseName !== undefined && data.inputColumns !== undefined && data.outputColumns !== undefined
|
||||
&& data.tableName !== undefined && data.databaseName !== constants.selectDatabaseTitle && data.tableName !== constants.selectTableTitle
|
||||
&& !data.inputColumns.find(x => x.columnName === constants.selectColumnTitle);
|
||||
&& !data.inputColumns.find(x => (x.columnName === constants.selectColumnTitle) || !x.columnName);
|
||||
if (!validated) {
|
||||
this.showErrorMessage(constants.invalidModelParametersError);
|
||||
}
|
||||
|
||||
@@ -209,20 +209,44 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
|
||||
if (this._modelBuilder) {
|
||||
|
||||
let nameInput = this._modelBuilder.dropDown().withProperties({
|
||||
values: dataTypes,
|
||||
width: this.componentMaxLength
|
||||
const outputContainer = this._modelBuilder.flexContainer().withLayout({
|
||||
flexFlow: 'row',
|
||||
width: this.componentMaxLength + 20,
|
||||
justifyContent: 'flex-start'
|
||||
}).component();
|
||||
const warningButton = this.createWarningButton(constants.outputColumnDataTypeNotSupportedWarning);
|
||||
warningButton.onDidClick(() => {
|
||||
});
|
||||
const css = {
|
||||
'padding-top': '5px',
|
||||
'padding-right': '5px',
|
||||
'margin': '0px'
|
||||
};
|
||||
const name = modelParameter.name;
|
||||
const dataType = dataTypes.find(x => x === modelParameter.type);
|
||||
if (dataType) {
|
||||
nameInput.value = dataType;
|
||||
} else {
|
||||
let dataType = dataTypes.find(x => x === modelParameter.type);
|
||||
if (!dataType) {
|
||||
// Output type not supported
|
||||
//
|
||||
modelParameter.type = dataTypes[0];
|
||||
dataType = dataTypes[0];
|
||||
outputContainer.addItem(warningButton, {
|
||||
CSSStyles: css
|
||||
});
|
||||
}
|
||||
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
|
||||
let nameInput = this._modelBuilder.dropDown().withProperties({
|
||||
values: dataTypes,
|
||||
width: this.componentMaxLength,
|
||||
value: dataType
|
||||
}).component();
|
||||
outputContainer.addItem(nameInput, {
|
||||
CSSStyles: {
|
||||
'padding': '0px',
|
||||
'padding-right': '5px',
|
||||
'margin': '0px'
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
this._parameters.push({ columnName: name, paramName: name, dataType: dataType });
|
||||
|
||||
nameInput.onValueChanged(() => {
|
||||
const value = <string>nameInput.value;
|
||||
@@ -231,8 +255,14 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
if (selectedRow) {
|
||||
selectedRow.dataType = value;
|
||||
}
|
||||
outputContainer.addItem(warningButton, {
|
||||
CSSStyles: css
|
||||
});
|
||||
} else {
|
||||
outputContainer.removeItem(warningButton);
|
||||
}
|
||||
});
|
||||
|
||||
let displayNameInput = this._modelBuilder.inputBox().withProperties({
|
||||
value: name,
|
||||
width: 200
|
||||
@@ -243,7 +273,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
selectedRow.columnName = displayNameInput.value || name;
|
||||
}
|
||||
});
|
||||
return [`${name}(${modelParameter.originalType ? modelParameter.originalType : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
|
||||
return [`${name}(${modelParameter.originalType ? modelParameter.originalType : constants.unsupportedModelParameterType})`, displayNameInput, outputContainer];
|
||||
}
|
||||
|
||||
return [];
|
||||
@@ -276,7 +306,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
width: this.componentMaxLength + 20,
|
||||
justifyContent: 'flex-start'
|
||||
}).component();
|
||||
const warningButton = this.createWarningButton();
|
||||
const warningButton = this.createWarningButton(constants.columnDataTypeMismatchWarning);
|
||||
warningButton.onDidClick(() => {
|
||||
});
|
||||
|
||||
@@ -296,7 +326,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
}
|
||||
|
||||
const currentColumn = columns.find(x => x.columnName === value);
|
||||
if (currentColumn && modelParameter.type === currentColumn?.dataType) {
|
||||
if (currentColumn && modelParameter.type !== currentColumn?.dataType) {
|
||||
inputContainer.removeItem(warningButton);
|
||||
} else {
|
||||
inputContainer.addItem(warningButton, {
|
||||
@@ -341,11 +371,11 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
return [];
|
||||
}
|
||||
|
||||
private createWarningButton(): azdata.ButtonComponent {
|
||||
private createWarningButton(message: string): azdata.ButtonComponent {
|
||||
const warningButton = this._modelBuilder.button().withProperties({
|
||||
width: '10px',
|
||||
height: '10px',
|
||||
title: constants.columnDataTypeMismatchWarning,
|
||||
title: message,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/warning_notification_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/warning_notification.svg'),
|
||||
|
||||
@@ -165,11 +165,13 @@ export class InstalledPackagesTab {
|
||||
*/
|
||||
public static async getLocationComponent(view: azdata.ModelView, dialog: ManagePackagesDialog): Promise<azdata.Component> {
|
||||
const locations = await dialog.model.getLocations();
|
||||
let location: string;
|
||||
let component: azdata.Component;
|
||||
if (locations && locations.length === 1) {
|
||||
component = view.modelBuilder.text().withProperties({
|
||||
value: locations[0].displayName
|
||||
}).component();
|
||||
location = locations[0].name;
|
||||
} else if (locations && locations.length > 1) {
|
||||
let dropdownValues = locations.map(x => {
|
||||
return {
|
||||
@@ -179,6 +181,7 @@ export class InstalledPackagesTab {
|
||||
});
|
||||
const currentLocation = await dialog.model.getCurrentLocation();
|
||||
const selectedLocation = dropdownValues.find(x => x.name === currentLocation);
|
||||
location = currentLocation || locations[0].name;
|
||||
let locationDropDown = view.modelBuilder.dropDown().withProperties({
|
||||
values: dropdownValues,
|
||||
value: selectedLocation || dropdownValues[0]
|
||||
@@ -198,8 +201,8 @@ export class InstalledPackagesTab {
|
||||
component = view.modelBuilder.text().withProperties({
|
||||
}).component();
|
||||
}
|
||||
if (locations && locations.length > 0) {
|
||||
dialog.changeLocation(locations[0].name);
|
||||
if (location) {
|
||||
dialog.changeLocation(location);
|
||||
}
|
||||
return component;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user