diff --git a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts index bc5417256a..6790285dee 100644 --- a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts +++ b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts @@ -64,13 +64,13 @@ export class ModelPythonClient { 'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double', 'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`, `type_map = { - onnx.TensorProto.DataType.FLOAT: 'real', - onnx.TensorProto.DataType.UINT8: 'tinyint', - onnx.TensorProto.DataType.INT16: 'smallint', - onnx.TensorProto.DataType.INT32: 'int', - onnx.TensorProto.DataType.INT64: 'bigint', - onnx.TensorProto.DataType.STRING: 'varchar(MAX)', - onnx.TensorProto.DataType.DOUBLE: 'float'}`, + onnx.TensorProto.DataType.FLOAT: 'REAL', + onnx.TensorProto.DataType.UINT8: 'TINYINT', + onnx.TensorProto.DataType.INT16: 'SMALLINT', + onnx.TensorProto.DataType.INT32: 'INT', + onnx.TensorProto.DataType.INT64: 'BIGINT', + onnx.TensorProto.DataType.STRING: 'VARCHAR(MAX)', + onnx.TensorProto.DataType.DOUBLE: 'FLOAT'}`, `parameters = { "inputs": [], "outputs": [] diff --git a/extensions/machine-learning/src/prediction/predictService.ts b/extensions/machine-learning/src/prediction/predictService.ts index 37bfb27554..27cd5b118c 100644 --- a/extensions/machine-learning/src/prediction/predictService.ts +++ b/extensions/machine-learning/src/prediction/predictService.ts @@ -9,7 +9,8 @@ import { ApiWrapper } from '../common/apiWrapper'; import { QueryRunner } from '../common/queryRunner'; import * as utils from '../common/utils'; import { ImportedModel } from '../modelManagement/interfaces'; -import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces'; +import { PredictParameters, DatabaseTable, TableColumn } from '../prediction/interfaces'; +import * as queries from './queries'; /** * Service to make prediction @@ -67,7 +68,7 @@ export class PredictService { let connection = await this.getCurrentConnection(); let query = ''; if (registeredModel && registeredModel.id) { - query = this.getPredictScriptWithModelId( + query = queries.getPredictScriptWithModelId( registeredModel.id, predictParams.inputColumns || [], predictParams.outputColumns || [], @@ -75,7 +76,7 @@ export class PredictService { registeredModel.table); } else if (filePath) { let modelBytes = await utils.readFileInHex(filePath || ''); - query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [], + query = queries.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [], predictParams.outputColumns || [], predictParams); } @@ -97,7 +98,7 @@ export class PredictService { let connection = await this.getCurrentConnection(); let list: DatabaseTable[] = []; if (connection) { - let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName)); + let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, queries.getTablesScript(databaseName)); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { result.rows.forEach(row => { @@ -120,13 +121,13 @@ export class PredictService { let connection = await this.getCurrentConnection(); let list: TableColumn[] = []; if (connection && databaseTable.databaseName) { - const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable)); + const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, queries.getTableColumnsScript(databaseTable)); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { result.rows.forEach(row => { list.push({ columnName: row[0].displayValue, - dataType: row[1].displayValue + dataType: row[1].displayValue.toLocaleUpperCase() }); }); } @@ -137,112 +138,5 @@ export class PredictService { private async getCurrentConnection(): Promise { return await this._apiWrapper.getCurrentConnection(); } - - private getTableColumnsScript(databaseTable: DatabaseTable): string { - return ` -SELECT COLUMN_NAME,DATA_TYPE -FROM INFORMATION_SCHEMA.COLUMNS -WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}' -AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}' -AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}' - `; - } - - private getTablesScript(databaseName: string): string { - return ` -SELECT TABLE_NAME,TABLE_SCHEMA -FROM INFORMATION_SCHEMA.TABLES -WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}' - `; - } - - private getPredictScriptWithModelId( - modelId: number, - columns: PredictColumn[], - outputColumns: PredictColumn[], - sourceTable: DatabaseTable, - importTable: DatabaseTable): string { - const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || ''); - return ` -DECLARE @model VARBINARY(max) = ( - SELECT model - FROM ${threePartTableName} - WHERE model_id = ${modelId} -); -WITH predict_input -AS ( - SELECT TOP 1000 - ${this.getInputColumnNames(columns, 'pi')} - FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi -) -SELECT -${this.getPredictColumnNames(columns, 'predict_input')}, -${this.getPredictInputColumnNames(outputColumns, 'p')} -FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx) -WITH ( - ${this.getOutputParameters(outputColumns)} -) AS p -`; - } - - private getPredictScriptWithModelBytes( - modelBytes: string, - columns: PredictColumn[], - outputColumns: PredictColumn[], - databaseNameTable: DatabaseTable): string { - return ` -WITH predict_input -AS ( - SELECT TOP 1000 - ${this.getInputColumnNames(columns, 'pi')} - FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi -) -SELECT -${this.getPredictColumnNames(columns, 'predict_input')}, -${this.getPredictInputColumnNames(outputColumns, 'p')} -FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx) -WITH ( - ${this.getOutputParameters(outputColumns)} -) AS p -`; - } - - private getEscapedColumnName(tableName: string, columnName: string): string { - return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`; - } - private getInputColumnNames(columns: PredictColumn[], tableName: string) { - - return columns.map(c => { - const column = this.getEscapedColumnName(tableName, c.columnName); - let columnName = c.dataType !== c.paramType ? `cast(${column} as ${c.paramType})` - : `${column}`; - return `${columnName} AS ${c.paramName}`; - }).join(',\n'); - } - - private getPredictInputColumnNames(columns: PredictColumn[], tableName: string) { - return columns.map(c => { - return this.getColumnName(tableName, c.paramName || '', c.columnName); - }).join(',\n'); - } - - private getColumnName(tableName: string, columnName: string, displayName: string) { - const column = this.getEscapedColumnName(tableName, columnName); - return columnName && columnName !== displayName ? - `${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column; - } - - private getPredictColumnNames(columns: PredictColumn[], tableName: string) { - return columns.map(c => { - return c.paramName ? `${this.getEscapedColumnName(tableName, c.paramName)}` - : `${this.getEscapedColumnName(tableName, c.columnName)}`; - }).join(',\n'); - } - - private getOutputParameters(columns: PredictColumn[]) { - return columns.map(c => { - return `${c.paramName} ${c.dataType}`; - }).join(',\n'); - } } diff --git a/extensions/machine-learning/src/prediction/queries.ts b/extensions/machine-learning/src/prediction/queries.ts new file mode 100644 index 0000000000..42ea3109f9 --- /dev/null +++ b/extensions/machine-learning/src/prediction/queries.ts @@ -0,0 +1,114 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import * as utils from '../common/utils'; +import { PredictColumn, DatabaseTable } from './interfaces'; + +export function getTableColumnsScript(databaseTable: DatabaseTable): string { + return ` +SELECT COLUMN_NAME,DATA_TYPE +FROM INFORMATION_SCHEMA.COLUMNS +WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}' +AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}' +AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}' + `; +} + +export function getTablesScript(databaseName: string): string { + return ` +SELECT TABLE_NAME,TABLE_SCHEMA +FROM INFORMATION_SCHEMA.TABLES +WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}' + `; +} + +export function getPredictScriptWithModelId( + modelId: number, + columns: PredictColumn[], + outputColumns: PredictColumn[], + sourceTable: DatabaseTable, + importTable: DatabaseTable): string { + const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || ''); + return ` +DECLARE @model VARBINARY(max) = ( +SELECT model +FROM ${threePartTableName} +WHERE model_id = ${modelId} +); +WITH predict_input +AS ( + SELECT TOP 1000 + ${getInputColumnNames(columns, 'pi')} +FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] AS pi +) +SELECT +${getPredictColumnNames(columns, 'predict_input')}, +${getPredictInputColumnNames(outputColumns, 'p')} +FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx) +WITH ( +${getOutputParameters(outputColumns)} +) AS p +`; +} + +export function getPredictScriptWithModelBytes( + modelBytes: string, + columns: PredictColumn[], + outputColumns: PredictColumn[], + databaseNameTable: DatabaseTable): string { + return ` +WITH predict_input +AS ( + SELECT TOP 1000 + ${getInputColumnNames(columns, 'pi')} +FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] AS pi +) +SELECT +${getPredictColumnNames(columns, 'predict_input')}, +${getPredictInputColumnNames(outputColumns, 'p')} +FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx) +WITH ( +${getOutputParameters(outputColumns)} +) AS p +`; +} + +export function getEscapedColumnName(tableName: string, columnName: string): string { + return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`; +} +export function getInputColumnNames(columns: PredictColumn[], tableName: string) { + + return columns.map(c => { + const column = getEscapedColumnName(tableName, c.columnName); + let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${c.paramType})` + : `${column}`; + return `${columnName} AS ${c.paramName}`; + }).join(',\n '); +} + +export function getPredictInputColumnNames(columns: PredictColumn[], tableName: string) { + return columns.map(c => { + return getColumnName(tableName, c.paramName || '', c.columnName); + }).join(',\n '); +} + +export function getColumnName(tableName: string, columnName: string, displayName: string) { + const column = getEscapedColumnName(tableName, columnName); + return columnName && columnName !== displayName ? + `${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column; +} + +export function getPredictColumnNames(columns: PredictColumn[], tableName: string) { + return columns.map(c => { + return c.paramName ? `${getEscapedColumnName(tableName, c.paramName)}` + : `${getEscapedColumnName(tableName, c.columnName)}`; + }).join(',\n'); +} + +export function getOutputParameters(columns: PredictColumn[]) { + return columns.map(c => { + return `${c.paramName} ${c.dataType}`; + }).join(',\n'); +} diff --git a/extensions/machine-learning/src/test/prediction/predictService.test.ts b/extensions/machine-learning/src/test/prediction/predictService.test.ts index 3a2effe744..83d9cad0af 100644 --- a/extensions/machine-learning/src/test/prediction/predictService.test.ts +++ b/extensions/machine-learning/src/test/prediction/predictService.test.ts @@ -114,11 +114,11 @@ describe('PredictService', () => { const expected: TableColumn[] = [ { columnName: 'c1', - dataType: 'int' + dataType: 'INT' }, { columnName: 'c2', - dataType: 'varchar' + dataType: 'VARCHAR' } ]; const table: DatabaseTable = diff --git a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts index 1e8b7b8690..88354239cd 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts @@ -20,13 +20,13 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent