From eec6f64d62f4535d877745c4466ff25e40cf985b Mon Sep 17 00:00:00 2001 From: Leila Lali Date: Mon, 26 Oct 2020 17:36:37 -0700 Subject: [PATCH] ML - Bug fixing (#13018) * Fixing couple of bugs --- .../machine-learning/src/common/constants.ts | 17 ++++++ .../src/controllers/mainController.ts | 18 ++++-- .../src/modelManagement/modelPythonClient.ts | 11 ++-- .../src/prediction/interfaces.ts | 1 + .../src/prediction/predictService.ts | 3 +- .../src/prediction/queries.ts | 12 ++-- .../test/prediction/predictService.test.ts | 56 +++++++++++++++++-- .../test/views/models/predictWizard.test.ts | 20 ++++--- .../views/models/prediction/columnsTable.ts | 21 +++---- .../prediction/inputColumnsComponent.ts | 6 +- 10 files changed, 121 insertions(+), 44 deletions(-) diff --git a/extensions/machine-learning/src/common/constants.ts b/extensions/machine-learning/src/common/constants.ts index 061d115d0f..08b0d84194 100644 --- a/extensions/machine-learning/src/common/constants.ts +++ b/extensions/machine-learning/src/common/constants.ts @@ -44,11 +44,28 @@ export const registeredModelsTableName = 'registeredModelsTableName'; export const rPathConfigKey = 'rPath'; export const adsPythonBundleVersion = '0.0.1'; +// TSQL +// + +// The data types that are supported to convert model's parameters to SQL data +export const supportedDataTypes = [ + 'BIGINT', + 'INT', + 'SMALLINT', + 'REAL', + 'FLOAT', + 'VARCHAR(MAX)', + 'BIT' +]; +export const varcharMax = 'VARCHAR(MAX)'; +export const varcharDefaultLength = 100; + // Localized texts // export const msgYes = localize('msgYes', "Yes"); export const msgNo = localize('msgNo', "No"); export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed."); +export const notebookExtensionFailedError = localize('notebookExtensionFailedError', "The extension failed to load because of it's dependency to Notebook extension. Please check the output log for Notebook extension to get more details"); export const verifyOdbcDriverError = localize('mls.verifyOdbcDriverError.error', "'{0}' is required for package management. Please make sure it is installed and set up correctly.", supportedODBCDriver); export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); } export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); } diff --git a/extensions/machine-learning/src/controllers/mainController.ts b/extensions/machine-learning/src/controllers/mainController.ts index f3fe8e5977..6a68f99f5c 100644 --- a/extensions/machine-learning/src/controllers/mainController.ts +++ b/extensions/machine-learning/src/controllers/mainController.ts @@ -62,12 +62,18 @@ export default class MainController implements vscode.Disposable { * Returns an instance of Server Installation from notebook extension */ private async getNotebookExtensionApis(): Promise { - let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName); - if (nbExtension) { - await nbExtension.activate(); - return (nbExtension.exports as nbExtensionApis.IExtensionApi); - } else { - throw new Error(constants.notebookExtensionNotLoaded); + try { + let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName); + if (nbExtension) { + await nbExtension.activate(); + return (nbExtension.exports as nbExtensionApis.IExtensionApi); + } else { + throw new Error(constants.notebookExtensionNotLoaded); + } + } catch (err) { + this._outputChannel.appendLine(constants.notebookExtensionFailedError); + this._apiWrapper.showErrorMessage(constants.notebookExtensionFailedError); + throw err; } } diff --git a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts index 6790285dee..dfa3158ba2 100644 --- a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts +++ b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts @@ -82,11 +82,12 @@ export class ModelPythonClient { if value in type_map: p_type = type_map[value] name = type_list[value] - parameters[paramType].append({ - 'name': p.name, - 'type': p_type, - 'originalType': name - })`, + if name != 'undefined': + parameters[paramType].append({ + 'name': p.name, + 'type': p_type, + 'originalType': name + })`, 'addParameters(onnx_model.graph.input, "inputs")', 'addParameters(onnx_model.graph.output, "outputs")', diff --git a/extensions/machine-learning/src/prediction/interfaces.ts b/extensions/machine-learning/src/prediction/interfaces.ts index f95995d5b5..d201224ce6 100644 --- a/extensions/machine-learning/src/prediction/interfaces.ts +++ b/extensions/machine-learning/src/prediction/interfaces.ts @@ -6,6 +6,7 @@ export interface TableColumn { columnName: string; dataType?: string; + maxLength?: number; } export interface PredictColumn extends TableColumn { diff --git a/extensions/machine-learning/src/prediction/predictService.ts b/extensions/machine-learning/src/prediction/predictService.ts index 27cd5b118c..2bcbad8e14 100644 --- a/extensions/machine-learning/src/prediction/predictService.ts +++ b/extensions/machine-learning/src/prediction/predictService.ts @@ -127,7 +127,8 @@ export class PredictService { result.rows.forEach(row => { list.push({ columnName: row[0].displayValue, - dataType: row[1].displayValue.toLocaleUpperCase() + dataType: row[1].displayValue.toLocaleUpperCase(), + maxLength: row[2].isNull ? undefined : +row[2].displayValue.toLocaleUpperCase() }); }); } diff --git a/extensions/machine-learning/src/prediction/queries.ts b/extensions/machine-learning/src/prediction/queries.ts index 42ea3109f9..61be7aa6db 100644 --- a/extensions/machine-learning/src/prediction/queries.ts +++ b/extensions/machine-learning/src/prediction/queries.ts @@ -5,10 +5,11 @@ import * as utils from '../common/utils'; import { PredictColumn, DatabaseTable } from './interfaces'; +import * as constants from '../common/constants'; export function getTableColumnsScript(databaseTable: DatabaseTable): string { return ` -SELECT COLUMN_NAME,DATA_TYPE +SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}' AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}' @@ -57,13 +58,13 @@ export function getPredictScriptWithModelBytes( modelBytes: string, columns: PredictColumn[], outputColumns: PredictColumn[], - databaseNameTable: DatabaseTable): string { + sourceTable: DatabaseTable): string { return ` WITH predict_input AS ( SELECT TOP 1000 ${getInputColumnNames(columns, 'pi')} -FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] AS pi +FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] AS pi ) SELECT ${getPredictColumnNames(columns, 'predict_input')}, @@ -78,11 +79,14 @@ ${getOutputParameters(outputColumns)} export function getEscapedColumnName(tableName: string, columnName: string): string { return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`; } + export function getInputColumnNames(columns: PredictColumn[], tableName: string) { return columns.map(c => { const column = getEscapedColumnName(tableName, c.columnName); - let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${c.paramType})` + const maxLength = c.maxLength !== undefined ? c.maxLength : constants.varcharDefaultLength; + let paramType = c.paramType === constants.varcharMax ? `VARCHAR(${maxLength})` : c.paramType; + let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${paramType})` : `${column}`; return `${columnName} AS ${c.paramName}`; }).join(',\n '); diff --git a/extensions/machine-learning/src/test/prediction/predictService.test.ts b/extensions/machine-learning/src/test/prediction/predictService.test.ts index 83d9cad0af..5959cb360d 100644 --- a/extensions/machine-learning/src/test/prediction/predictService.test.ts +++ b/extensions/machine-learning/src/test/prediction/predictService.test.ts @@ -6,12 +6,13 @@ import * as azdata from 'azdata'; import * as vscode from 'vscode'; import { ApiWrapper } from '../../common/apiWrapper'; +import * as Queries from '../../prediction/queries'; import * as TypeMoq from 'typemoq'; import * as should from 'should'; import { PredictService } from '../../prediction/predictService'; import { QueryRunner } from '../../common/queryRunner'; import { ImportedModel } from '../../modelManagement/interfaces'; -import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; +import { PredictParameters, DatabaseTable, TableColumn, PredictColumn } from '../../prediction/interfaces'; import * as path from 'path'; import * as os from 'os'; import * as UUID from 'vscode-languageclient/lib/utils/uuid'; @@ -114,11 +115,13 @@ describe('PredictService', () => { const expected: TableColumn[] = [ { columnName: 'c1', - dataType: 'INT' + dataType: 'INT', + maxLength: undefined }, { columnName: 'c2', - dataType: 'VARCHAR' + dataType: 'VARCHAR', + maxLength: 10 } ]; const table: DatabaseTable = @@ -141,6 +144,11 @@ describe('PredictService', () => { displayValue: 'int', isNull: false, invariantCultureDisplayValue: '' + }, + { + displayValue: '', + isNull: true, + invariantCultureDisplayValue: '' } ], [ { @@ -152,6 +160,11 @@ describe('PredictService', () => { displayValue: 'varchar', isNull: false, invariantCultureDisplayValue: '' + }, + { + displayValue: '10', + isNull: false, + invariantCultureDisplayValue: '' } ]] }; @@ -175,12 +188,14 @@ describe('PredictService', () => { { paramName: 'p1', dataType: 'int', - columnName: '' + columnName: '', + maxLength: undefined }, { paramName: 'p2', dataType: 'varchar', - columnName: '' + columnName: '', + maxLength: 10 } ], outputColumns: [ @@ -298,4 +313,35 @@ describe('PredictService', () => { should.notEqual(actual, undefined); should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true); }); + + it('getInputColumnNames should user column max length for varchar type', async function (): Promise { + const columns: PredictColumn[] = [ + { + paramName: 'p1', + paramType: 'VARCHAR(MAX)', + columnName: 'c1', + dataType: 'VARCHAR', + maxLength: 20 + }, + { + paramName: 'p2', + paramType: 'VARCHAR(MAX)', + columnName: 'c2', + dataType: 'DATETIME', + maxLength: undefined + }, + { + paramName: 'p3', + paramType: 'INT', + columnName: 'c2', + dataType: 'INT', + maxLength: undefined + }, + ]; + + const tableName = 'tbname'; + let actual = Queries.getInputColumnNames(columns, tableName); + let expected =`CAST([tbname].[c1] AS VARCHAR(20)) AS p1,\n\tCAST([tbname].[c2] AS VARCHAR(100)) AS p2,\n\t[tbname].[c2] AS p3`; + should.deepEqual(actual, expected); + }); }); diff --git a/extensions/machine-learning/src/test/views/models/predictWizard.test.ts b/extensions/machine-learning/src/test/views/models/predictWizard.test.ts index 0637b2483e..c3b683f220 100644 --- a/extensions/machine-learning/src/test/views/models/predictWizard.test.ts +++ b/extensions/machine-learning/src/test/views/models/predictWizard.test.ts @@ -180,7 +180,7 @@ describe('Predict Wizard', () => { view.modelBrowsePage.modelSourceType = ModelSourceType.Azure; } await view.refresh(); - should.notEqual(view.azureModelsComponent?.data, undefined); + should.notEqual(view.azureModelsComponent?.data, undefined, 'Data from Azure component should not be null'); if (view.modelBrowsePage) { view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels; @@ -188,19 +188,21 @@ describe('Predict Wizard', () => { await view.refresh(); testContext.onClick.fire(undefined); - should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels); - should.notEqual(view.localModelsComponent?.data, undefined); - should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined); + + should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels, 'Model source should be registered models'); + should.notEqual(view.localModelsComponent?.data, undefined, 'Data from local model component should not be null'); + should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined, 'Data from registered model component should not be null'); if (view.modelBrowsePage?.registeredModelsComponent?.data) { - should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1); + should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1, 'Data from registered model component should not be empty'); } - should.notEqual(await view.getModelFileName(), undefined); + should.notEqual(await view.getModelFileName(), undefined, 'Model file name should not be null'); await view.columnsSelectionPage?.onEnter(); + await view.columnsSelectionPage?.inputColumnsComponent?.loadWithTable(tableNames[0]); - should.notEqual(view.columnsSelectionPage?.data, undefined); - should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name); - should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length); + should.notEqual(view.columnsSelectionPage?.data, undefined, 'Data from column selection component should not be null'); + should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, `unexpected number of inputs. ${view.columnsSelectionPage?.data?.inputColumns?.length}` ); + should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length, `unexpected number of outputs. ${view.columnsSelectionPage?.data?.outputColumns?.length}`); }); }); diff --git a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts index 218fe6ab9e..bf359f9970 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts @@ -19,15 +19,6 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent this.createOutputTableRow(output, this._dataTypes))); + if (modelParameters?.outputs && constants.supportedDataTypes) { + tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, constants.supportedDataTypes))); } } @@ -299,7 +290,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent x.paramName === name); if (selectedRow) { selectedRow.columnName = value || ''; + let tableColumn = columns.find(x => x.columnName === value); + if (tableColumn) { + selectedRow.maxLength = tableColumn.maxLength; + } } const currentColumn = columns.find(x => x.columnName === value); diff --git a/extensions/machine-learning/src/views/models/prediction/inputColumnsComponent.ts b/extensions/machine-learning/src/views/models/prediction/inputColumnsComponent.ts index c34506508f..5c10f76697 100644 --- a/extensions/machine-learning/src/views/models/prediction/inputColumnsComponent.ts +++ b/extensions/machine-learning/src/views/models/prediction/inputColumnsComponent.ts @@ -132,7 +132,11 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone } private async onTableSelected(): Promise { - this._columns?.loadInputs(this._modelParameters, this.databaseTable); + await this.loadWithTable(this.databaseTable); + } + + public async loadWithTable(table: DatabaseTable): Promise { + await this._columns?.loadInputs(this._modelParameters, table); } private get databaseTable(): DatabaseTable {