ML - Bug fixing (#13018)

* Fixing couple of bugs
This commit is contained in:
Leila Lali
2020-10-26 17:36:37 -07:00
committed by GitHub
parent 20ed569a71
commit eec6f64d62
10 changed files with 121 additions and 44 deletions

View File

@@ -44,11 +44,28 @@ export const registeredModelsTableName = 'registeredModelsTableName';
export const rPathConfigKey = 'rPath';
export const adsPythonBundleVersion = '0.0.1';
// TSQL
//
// The data types that are supported to convert model's parameters to SQL data
export const supportedDataTypes = [
'BIGINT',
'INT',
'SMALLINT',
'REAL',
'FLOAT',
'VARCHAR(MAX)',
'BIT'
];
export const varcharMax = 'VARCHAR(MAX)';
export const varcharDefaultLength = 100;
// Localized texts
//
export const msgYes = localize('msgYes', "Yes");
export const msgNo = localize('msgNo', "No");
export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed.");
export const notebookExtensionFailedError = localize('notebookExtensionFailedError', "The extension failed to load because of it's dependency to Notebook extension. Please check the output log for Notebook extension to get more details");
export const verifyOdbcDriverError = localize('mls.verifyOdbcDriverError.error', "'{0}' is required for package management. Please make sure it is installed and set up correctly.", supportedODBCDriver);
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); }

View File

@@ -62,12 +62,18 @@ export default class MainController implements vscode.Disposable {
* Returns an instance of Server Installation from notebook extension
*/
private async getNotebookExtensionApis(): Promise<nbExtensionApis.IExtensionApi> {
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
if (nbExtension) {
await nbExtension.activate();
return (nbExtension.exports as nbExtensionApis.IExtensionApi);
} else {
throw new Error(constants.notebookExtensionNotLoaded);
try {
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
if (nbExtension) {
await nbExtension.activate();
return (nbExtension.exports as nbExtensionApis.IExtensionApi);
} else {
throw new Error(constants.notebookExtensionNotLoaded);
}
} catch (err) {
this._outputChannel.appendLine(constants.notebookExtensionFailedError);
this._apiWrapper.showErrorMessage(constants.notebookExtensionFailedError);
throw err;
}
}

View File

@@ -82,11 +82,12 @@ export class ModelPythonClient {
if value in type_map:
p_type = type_map[value]
name = type_list[value]
parameters[paramType].append({
'name': p.name,
'type': p_type,
'originalType': name
})`,
if name != 'undefined':
parameters[paramType].append({
'name': p.name,
'type': p_type,
'originalType': name
})`,
'addParameters(onnx_model.graph.input, "inputs")',
'addParameters(onnx_model.graph.output, "outputs")',

View File

@@ -6,6 +6,7 @@
export interface TableColumn {
columnName: string;
dataType?: string;
maxLength?: number;
}
export interface PredictColumn extends TableColumn {

View File

@@ -127,7 +127,8 @@ export class PredictService {
result.rows.forEach(row => {
list.push({
columnName: row[0].displayValue,
dataType: row[1].displayValue.toLocaleUpperCase()
dataType: row[1].displayValue.toLocaleUpperCase(),
maxLength: row[2].isNull ? undefined : +row[2].displayValue.toLocaleUpperCase()
});
});
}

View File

@@ -5,10 +5,11 @@
import * as utils from '../common/utils';
import { PredictColumn, DatabaseTable } from './interfaces';
import * as constants from '../common/constants';
export function getTableColumnsScript(databaseTable: DatabaseTable): string {
return `
SELECT COLUMN_NAME,DATA_TYPE
SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
@@ -57,13 +58,13 @@ export function getPredictScriptWithModelBytes(
modelBytes: string,
columns: PredictColumn[],
outputColumns: PredictColumn[],
databaseNameTable: DatabaseTable): string {
sourceTable: DatabaseTable): string {
return `
WITH predict_input
AS (
SELECT TOP 1000
${getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] AS pi
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] AS pi
)
SELECT
${getPredictColumnNames(columns, 'predict_input')},
@@ -78,11 +79,14 @@ ${getOutputParameters(outputColumns)}
export function getEscapedColumnName(tableName: string, columnName: string): string {
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
}
export function getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
const column = getEscapedColumnName(tableName, c.columnName);
let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${c.paramType})`
const maxLength = c.maxLength !== undefined ? c.maxLength : constants.varcharDefaultLength;
let paramType = c.paramType === constants.varcharMax ? `VARCHAR(${maxLength})` : c.paramType;
let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${paramType})`
: `${column}`;
return `${columnName} AS ${c.paramName}`;
}).join(',\n ');

View File

@@ -6,12 +6,13 @@
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper';
import * as Queries from '../../prediction/queries';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { PredictService } from '../../prediction/predictService';
import { QueryRunner } from '../../common/queryRunner';
import { ImportedModel } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { PredictParameters, DatabaseTable, TableColumn, PredictColumn } from '../../prediction/interfaces';
import * as path from 'path';
import * as os from 'os';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
@@ -114,11 +115,13 @@ describe('PredictService', () => {
const expected: TableColumn[] = [
{
columnName: 'c1',
dataType: 'INT'
dataType: 'INT',
maxLength: undefined
},
{
columnName: 'c2',
dataType: 'VARCHAR'
dataType: 'VARCHAR',
maxLength: 10
}
];
const table: DatabaseTable =
@@ -141,6 +144,11 @@ describe('PredictService', () => {
displayValue: 'int',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '',
isNull: true,
invariantCultureDisplayValue: ''
}
], [
{
@@ -152,6 +160,11 @@ describe('PredictService', () => {
displayValue: 'varchar',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '10',
isNull: false,
invariantCultureDisplayValue: ''
}
]]
};
@@ -175,12 +188,14 @@ describe('PredictService', () => {
{
paramName: 'p1',
dataType: 'int',
columnName: ''
columnName: '',
maxLength: undefined
},
{
paramName: 'p2',
dataType: 'varchar',
columnName: ''
columnName: '',
maxLength: 10
}
],
outputColumns: [
@@ -298,4 +313,35 @@ describe('PredictService', () => {
should.notEqual(actual, undefined);
should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
});
it('getInputColumnNames should user column max length for varchar type', async function (): Promise<void> {
const columns: PredictColumn[] = [
{
paramName: 'p1',
paramType: 'VARCHAR(MAX)',
columnName: 'c1',
dataType: 'VARCHAR',
maxLength: 20
},
{
paramName: 'p2',
paramType: 'VARCHAR(MAX)',
columnName: 'c2',
dataType: 'DATETIME',
maxLength: undefined
},
{
paramName: 'p3',
paramType: 'INT',
columnName: 'c2',
dataType: 'INT',
maxLength: undefined
},
];
const tableName = 'tbname';
let actual = Queries.getInputColumnNames(columns, tableName);
let expected =`CAST([tbname].[c1] AS VARCHAR(20)) AS p1,\n\tCAST([tbname].[c2] AS VARCHAR(100)) AS p2,\n\t[tbname].[c2] AS p3`;
should.deepEqual(actual, expected);
});
});

View File

@@ -180,7 +180,7 @@ describe('Predict Wizard', () => {
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
}
await view.refresh();
should.notEqual(view.azureModelsComponent?.data, undefined);
should.notEqual(view.azureModelsComponent?.data, undefined, 'Data from Azure component should not be null');
if (view.modelBrowsePage) {
view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
@@ -188,19 +188,21 @@ describe('Predict Wizard', () => {
await view.refresh();
testContext.onClick.fire(undefined);
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels);
should.notEqual(view.localModelsComponent?.data, undefined);
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined);
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels, 'Model source should be registered models');
should.notEqual(view.localModelsComponent?.data, undefined, 'Data from local model component should not be null');
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined, 'Data from registered model component should not be null');
if (view.modelBrowsePage?.registeredModelsComponent?.data) {
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1);
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1, 'Data from registered model component should not be empty');
}
should.notEqual(await view.getModelFileName(), undefined);
should.notEqual(await view.getModelFileName(), undefined, 'Model file name should not be null');
await view.columnsSelectionPage?.onEnter();
await view.columnsSelectionPage?.inputColumnsComponent?.loadWithTable(tableNames[0]);
should.notEqual(view.columnsSelectionPage?.data, undefined);
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name);
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length);
should.notEqual(view.columnsSelectionPage?.data, undefined, 'Data from column selection component should not be null');
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, `unexpected number of inputs. ${view.columnsSelectionPage?.data?.inputColumns?.length}` );
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length, `unexpected number of outputs. ${view.columnsSelectionPage?.data?.outputColumns?.length}`);
});
});

View File

@@ -19,15 +19,6 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
private _table: azdata.DeclarativeTableComponent | undefined;
private _parameters: PredictColumn[] = [];
private _loader: azdata.LoadingComponent;
private _dataTypes: string[] = [
'BIGINT',
'INT',
'SMALLINT',
'REAL',
'FLOAT',
'VARCHAR(MAX)',
'BIT'
];
/**
@@ -171,7 +162,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
this._parameters = [];
let tableData: any[][] = [];
if (this._table) {
if (this._table && table && table.tableName !== constants.selectTableTitle) {
if (this._forInput) {
let columns: TableColumn[];
try {
@@ -196,8 +187,8 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
if (this._table) {
if (!this._forInput) {
if (modelParameters?.outputs && this._dataTypes) {
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
if (modelParameters?.outputs && constants.supportedDataTypes) {
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, constants.supportedDataTypes)));
}
}
@@ -299,7 +290,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
nameInput.value = column;
if (column) {
this._parameters.push({ columnName: column.name, paramName: name, paramType: modelParameter.type });
this._parameters.push({ columnName: column.name, paramName: name, paramType: modelParameter.type, maxLength: currentColumn?.maxLength });
}
const inputContainer = this._modelBuilder.flexContainer().withLayout({
flexFlow: 'row',
@@ -323,6 +314,10 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.columnName = value || '';
let tableColumn = columns.find(x => x.columnName === value);
if (tableColumn) {
selectedRow.maxLength = tableColumn.maxLength;
}
}
const currentColumn = columns.find(x => x.columnName === value);

View File

@@ -132,7 +132,11 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
}
private async onTableSelected(): Promise<void> {
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
await this.loadWithTable(this.databaseTable);
}
public async loadWithTable(table: DatabaseTable): Promise<void> {
await this._columns?.loadInputs(this._modelParameters, table);
}
private get databaseTable(): DatabaseTable {