ML - Bug fixing (#13018)

* Fixing couple of bugs
This commit is contained in:
Leila Lali
2020-10-26 17:36:37 -07:00
committed by GitHub
parent 20ed569a71
commit eec6f64d62
10 changed files with 121 additions and 44 deletions

View File

@@ -44,11 +44,28 @@ export const registeredModelsTableName = 'registeredModelsTableName';
export const rPathConfigKey = 'rPath'; export const rPathConfigKey = 'rPath';
export const adsPythonBundleVersion = '0.0.1'; export const adsPythonBundleVersion = '0.0.1';
// TSQL
//
// The data types that are supported to convert model's parameters to SQL data
export const supportedDataTypes = [
'BIGINT',
'INT',
'SMALLINT',
'REAL',
'FLOAT',
'VARCHAR(MAX)',
'BIT'
];
export const varcharMax = 'VARCHAR(MAX)';
export const varcharDefaultLength = 100;
// Localized texts // Localized texts
// //
export const msgYes = localize('msgYes', "Yes"); export const msgYes = localize('msgYes', "Yes");
export const msgNo = localize('msgNo', "No"); export const msgNo = localize('msgNo', "No");
export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed."); export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed.");
export const notebookExtensionFailedError = localize('notebookExtensionFailedError', "The extension failed to load because of it's dependency to Notebook extension. Please check the output log for Notebook extension to get more details");
export const verifyOdbcDriverError = localize('mls.verifyOdbcDriverError.error', "'{0}' is required for package management. Please make sure it is installed and set up correctly.", supportedODBCDriver); export const verifyOdbcDriverError = localize('mls.verifyOdbcDriverError.error', "'{0}' is required for package management. Please make sure it is installed and set up correctly.", supportedODBCDriver);
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); } export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); } export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); }

View File

@@ -62,12 +62,18 @@ export default class MainController implements vscode.Disposable {
* Returns an instance of Server Installation from notebook extension * Returns an instance of Server Installation from notebook extension
*/ */
private async getNotebookExtensionApis(): Promise<nbExtensionApis.IExtensionApi> { private async getNotebookExtensionApis(): Promise<nbExtensionApis.IExtensionApi> {
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName); try {
if (nbExtension) { let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
await nbExtension.activate(); if (nbExtension) {
return (nbExtension.exports as nbExtensionApis.IExtensionApi); await nbExtension.activate();
} else { return (nbExtension.exports as nbExtensionApis.IExtensionApi);
throw new Error(constants.notebookExtensionNotLoaded); } else {
throw new Error(constants.notebookExtensionNotLoaded);
}
} catch (err) {
this._outputChannel.appendLine(constants.notebookExtensionFailedError);
this._apiWrapper.showErrorMessage(constants.notebookExtensionFailedError);
throw err;
} }
} }

View File

@@ -82,11 +82,12 @@ export class ModelPythonClient {
if value in type_map: if value in type_map:
p_type = type_map[value] p_type = type_map[value]
name = type_list[value] name = type_list[value]
parameters[paramType].append({ if name != 'undefined':
'name': p.name, parameters[paramType].append({
'type': p_type, 'name': p.name,
'originalType': name 'type': p_type,
})`, 'originalType': name
})`,
'addParameters(onnx_model.graph.input, "inputs")', 'addParameters(onnx_model.graph.input, "inputs")',
'addParameters(onnx_model.graph.output, "outputs")', 'addParameters(onnx_model.graph.output, "outputs")',

View File

@@ -6,6 +6,7 @@
export interface TableColumn { export interface TableColumn {
columnName: string; columnName: string;
dataType?: string; dataType?: string;
maxLength?: number;
} }
export interface PredictColumn extends TableColumn { export interface PredictColumn extends TableColumn {

View File

@@ -127,7 +127,8 @@ export class PredictService {
result.rows.forEach(row => { result.rows.forEach(row => {
list.push({ list.push({
columnName: row[0].displayValue, columnName: row[0].displayValue,
dataType: row[1].displayValue.toLocaleUpperCase() dataType: row[1].displayValue.toLocaleUpperCase(),
maxLength: row[2].isNull ? undefined : +row[2].displayValue.toLocaleUpperCase()
}); });
}); });
} }

View File

@@ -5,10 +5,11 @@
import * as utils from '../common/utils'; import * as utils from '../common/utils';
import { PredictColumn, DatabaseTable } from './interfaces'; import { PredictColumn, DatabaseTable } from './interfaces';
import * as constants from '../common/constants';
export function getTableColumnsScript(databaseTable: DatabaseTable): string { export function getTableColumnsScript(databaseTable: DatabaseTable): string {
return ` return `
SELECT COLUMN_NAME,DATA_TYPE SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH
FROM INFORMATION_SCHEMA.COLUMNS FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}' WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}' AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
@@ -57,13 +58,13 @@ export function getPredictScriptWithModelBytes(
modelBytes: string, modelBytes: string,
columns: PredictColumn[], columns: PredictColumn[],
outputColumns: PredictColumn[], outputColumns: PredictColumn[],
databaseNameTable: DatabaseTable): string { sourceTable: DatabaseTable): string {
return ` return `
WITH predict_input WITH predict_input
AS ( AS (
SELECT TOP 1000 SELECT TOP 1000
${getInputColumnNames(columns, 'pi')} ${getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] AS pi FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] AS pi
) )
SELECT SELECT
${getPredictColumnNames(columns, 'predict_input')}, ${getPredictColumnNames(columns, 'predict_input')},
@@ -78,11 +79,14 @@ ${getOutputParameters(outputColumns)}
export function getEscapedColumnName(tableName: string, columnName: string): string { export function getEscapedColumnName(tableName: string, columnName: string): string {
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`; return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
} }
export function getInputColumnNames(columns: PredictColumn[], tableName: string) { export function getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => { return columns.map(c => {
const column = getEscapedColumnName(tableName, c.columnName); const column = getEscapedColumnName(tableName, c.columnName);
let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${c.paramType})` const maxLength = c.maxLength !== undefined ? c.maxLength : constants.varcharDefaultLength;
let paramType = c.paramType === constants.varcharMax ? `VARCHAR(${maxLength})` : c.paramType;
let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${paramType})`
: `${column}`; : `${column}`;
return `${columnName} AS ${c.paramName}`; return `${columnName} AS ${c.paramName}`;
}).join(',\n '); }).join(',\n ');

View File

@@ -6,12 +6,13 @@
import * as azdata from 'azdata'; import * as azdata from 'azdata';
import * as vscode from 'vscode'; import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper'; import { ApiWrapper } from '../../common/apiWrapper';
import * as Queries from '../../prediction/queries';
import * as TypeMoq from 'typemoq'; import * as TypeMoq from 'typemoq';
import * as should from 'should'; import * as should from 'should';
import { PredictService } from '../../prediction/predictService'; import { PredictService } from '../../prediction/predictService';
import { QueryRunner } from '../../common/queryRunner'; import { QueryRunner } from '../../common/queryRunner';
import { ImportedModel } from '../../modelManagement/interfaces'; import { ImportedModel } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces'; import { PredictParameters, DatabaseTable, TableColumn, PredictColumn } from '../../prediction/interfaces';
import * as path from 'path'; import * as path from 'path';
import * as os from 'os'; import * as os from 'os';
import * as UUID from 'vscode-languageclient/lib/utils/uuid'; import * as UUID from 'vscode-languageclient/lib/utils/uuid';
@@ -114,11 +115,13 @@ describe('PredictService', () => {
const expected: TableColumn[] = [ const expected: TableColumn[] = [
{ {
columnName: 'c1', columnName: 'c1',
dataType: 'INT' dataType: 'INT',
maxLength: undefined
}, },
{ {
columnName: 'c2', columnName: 'c2',
dataType: 'VARCHAR' dataType: 'VARCHAR',
maxLength: 10
} }
]; ];
const table: DatabaseTable = const table: DatabaseTable =
@@ -141,6 +144,11 @@ describe('PredictService', () => {
displayValue: 'int', displayValue: 'int',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
},
{
displayValue: '',
isNull: true,
invariantCultureDisplayValue: ''
} }
], [ ], [
{ {
@@ -152,6 +160,11 @@ describe('PredictService', () => {
displayValue: 'varchar', displayValue: 'varchar',
isNull: false, isNull: false,
invariantCultureDisplayValue: '' invariantCultureDisplayValue: ''
},
{
displayValue: '10',
isNull: false,
invariantCultureDisplayValue: ''
} }
]] ]]
}; };
@@ -175,12 +188,14 @@ describe('PredictService', () => {
{ {
paramName: 'p1', paramName: 'p1',
dataType: 'int', dataType: 'int',
columnName: '' columnName: '',
maxLength: undefined
}, },
{ {
paramName: 'p2', paramName: 'p2',
dataType: 'varchar', dataType: 'varchar',
columnName: '' columnName: '',
maxLength: 10
} }
], ],
outputColumns: [ outputColumns: [
@@ -298,4 +313,35 @@ describe('PredictService', () => {
should.notEqual(actual, undefined); should.notEqual(actual, undefined);
should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true); should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
}); });
it('getInputColumnNames should user column max length for varchar type', async function (): Promise<void> {
const columns: PredictColumn[] = [
{
paramName: 'p1',
paramType: 'VARCHAR(MAX)',
columnName: 'c1',
dataType: 'VARCHAR',
maxLength: 20
},
{
paramName: 'p2',
paramType: 'VARCHAR(MAX)',
columnName: 'c2',
dataType: 'DATETIME',
maxLength: undefined
},
{
paramName: 'p3',
paramType: 'INT',
columnName: 'c2',
dataType: 'INT',
maxLength: undefined
},
];
const tableName = 'tbname';
let actual = Queries.getInputColumnNames(columns, tableName);
let expected =`CAST([tbname].[c1] AS VARCHAR(20)) AS p1,\n\tCAST([tbname].[c2] AS VARCHAR(100)) AS p2,\n\t[tbname].[c2] AS p3`;
should.deepEqual(actual, expected);
});
}); });

View File

@@ -180,7 +180,7 @@ describe('Predict Wizard', () => {
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure; view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
} }
await view.refresh(); await view.refresh();
should.notEqual(view.azureModelsComponent?.data, undefined); should.notEqual(view.azureModelsComponent?.data, undefined, 'Data from Azure component should not be null');
if (view.modelBrowsePage) { if (view.modelBrowsePage) {
view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels; view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
@@ -188,19 +188,21 @@ describe('Predict Wizard', () => {
await view.refresh(); await view.refresh();
testContext.onClick.fire(undefined); testContext.onClick.fire(undefined);
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels);
should.notEqual(view.localModelsComponent?.data, undefined); should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels, 'Model source should be registered models');
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined); should.notEqual(view.localModelsComponent?.data, undefined, 'Data from local model component should not be null');
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined, 'Data from registered model component should not be null');
if (view.modelBrowsePage?.registeredModelsComponent?.data) { if (view.modelBrowsePage?.registeredModelsComponent?.data) {
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1); should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1, 'Data from registered model component should not be empty');
} }
should.notEqual(await view.getModelFileName(), undefined); should.notEqual(await view.getModelFileName(), undefined, 'Model file name should not be null');
await view.columnsSelectionPage?.onEnter(); await view.columnsSelectionPage?.onEnter();
await view.columnsSelectionPage?.inputColumnsComponent?.loadWithTable(tableNames[0]);
should.notEqual(view.columnsSelectionPage?.data, undefined); should.notEqual(view.columnsSelectionPage?.data, undefined, 'Data from column selection component should not be null');
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name); should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, `unexpected number of inputs. ${view.columnsSelectionPage?.data?.inputColumns?.length}` );
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length); should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length, `unexpected number of outputs. ${view.columnsSelectionPage?.data?.outputColumns?.length}`);
}); });
}); });

View File

@@ -19,15 +19,6 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
private _table: azdata.DeclarativeTableComponent | undefined; private _table: azdata.DeclarativeTableComponent | undefined;
private _parameters: PredictColumn[] = []; private _parameters: PredictColumn[] = [];
private _loader: azdata.LoadingComponent; private _loader: azdata.LoadingComponent;
private _dataTypes: string[] = [
'BIGINT',
'INT',
'SMALLINT',
'REAL',
'FLOAT',
'VARCHAR(MAX)',
'BIT'
];
/** /**
@@ -171,7 +162,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
this._parameters = []; this._parameters = [];
let tableData: any[][] = []; let tableData: any[][] = [];
if (this._table) { if (this._table && table && table.tableName !== constants.selectTableTitle) {
if (this._forInput) { if (this._forInput) {
let columns: TableColumn[]; let columns: TableColumn[];
try { try {
@@ -196,8 +187,8 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
if (this._table) { if (this._table) {
if (!this._forInput) { if (!this._forInput) {
if (modelParameters?.outputs && this._dataTypes) { if (modelParameters?.outputs && constants.supportedDataTypes) {
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes))); tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, constants.supportedDataTypes)));
} }
} }
@@ -299,7 +290,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
nameInput.value = column; nameInput.value = column;
if (column) { if (column) {
this._parameters.push({ columnName: column.name, paramName: name, paramType: modelParameter.type }); this._parameters.push({ columnName: column.name, paramName: name, paramType: modelParameter.type, maxLength: currentColumn?.maxLength });
} }
const inputContainer = this._modelBuilder.flexContainer().withLayout({ const inputContainer = this._modelBuilder.flexContainer().withLayout({
flexFlow: 'row', flexFlow: 'row',
@@ -323,6 +314,10 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
let selectedRow = this._parameters.find(x => x.paramName === name); let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) { if (selectedRow) {
selectedRow.columnName = value || ''; selectedRow.columnName = value || '';
let tableColumn = columns.find(x => x.columnName === value);
if (tableColumn) {
selectedRow.maxLength = tableColumn.maxLength;
}
} }
const currentColumn = columns.find(x => x.columnName === value); const currentColumn = columns.find(x => x.columnName === value);

View File

@@ -132,7 +132,11 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
} }
private async onTableSelected(): Promise<void> { private async onTableSelected(): Promise<void> {
this._columns?.loadInputs(this._modelParameters, this.databaseTable); await this.loadWithTable(this.databaseTable);
}
public async loadWithTable(table: DatabaseTable): Promise<void> {
await this._columns?.loadInputs(this._modelParameters, table);
} }
private get databaseTable(): DatabaseTable { private get databaseTable(): DatabaseTable {