ML - Fixed script formatting for prediction (#11767)

* Fixed script formatting for prediction
This commit is contained in:
Leila Lali
2020-08-12 13:36:39 -07:00
committed by GitHub
parent 094ee7c50c
commit e40a81e8e1
5 changed files with 137 additions and 129 deletions

View File

@@ -9,7 +9,8 @@ import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import * as utils from '../common/utils';
import { ImportedModel } from '../modelManagement/interfaces';
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../prediction/interfaces';
import * as queries from './queries';
/**
* Service to make prediction
@@ -67,7 +68,7 @@ export class PredictService {
let connection = await this.getCurrentConnection();
let query = '';
if (registeredModel && registeredModel.id) {
query = this.getPredictScriptWithModelId(
query = queries.getPredictScriptWithModelId(
registeredModel.id,
predictParams.inputColumns || [],
predictParams.outputColumns || [],
@@ -75,7 +76,7 @@ export class PredictService {
registeredModel.table);
} else if (filePath) {
let modelBytes = await utils.readFileInHex(filePath || '');
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
query = queries.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
predictParams.outputColumns || [],
predictParams);
}
@@ -97,7 +98,7 @@ export class PredictService {
let connection = await this.getCurrentConnection();
let list: DatabaseTable[] = [];
if (connection) {
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName));
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, queries.getTablesScript(databaseName));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
@@ -120,13 +121,13 @@ export class PredictService {
let connection = await this.getCurrentConnection();
let list: TableColumn[] = [];
if (connection && databaseTable.databaseName) {
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, queries.getTableColumnsScript(databaseTable));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push({
columnName: row[0].displayValue,
dataType: row[1].displayValue
dataType: row[1].displayValue.toLocaleUpperCase()
});
});
}
@@ -137,112 +138,5 @@ export class PredictService {
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
private getTableColumnsScript(databaseTable: DatabaseTable): string {
return `
SELECT COLUMN_NAME,DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
`;
}
private getTablesScript(databaseName: string): string {
return `
SELECT TABLE_NAME,TABLE_SCHEMA
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
`;
}
private getPredictScriptWithModelId(
modelId: number,
columns: PredictColumn[],
outputColumns: PredictColumn[],
sourceTable: DatabaseTable,
importTable: DatabaseTable): string {
const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || '');
return `
DECLARE @model VARBINARY(max) = (
SELECT model
FROM ${threePartTableName}
WHERE model_id = ${modelId}
);
WITH predict_input
AS (
SELECT TOP 1000
${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
)
SELECT
${this.getPredictColumnNames(columns, 'predict_input')},
${this.getPredictInputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx)
WITH (
${this.getOutputParameters(outputColumns)}
) AS p
`;
}
private getPredictScriptWithModelBytes(
modelBytes: string,
columns: PredictColumn[],
outputColumns: PredictColumn[],
databaseNameTable: DatabaseTable): string {
return `
WITH predict_input
AS (
SELECT TOP 1000
${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
${this.getPredictColumnNames(columns, 'predict_input')},
${this.getPredictInputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx)
WITH (
${this.getOutputParameters(outputColumns)}
) AS p
`;
}
private getEscapedColumnName(tableName: string, columnName: string): string {
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
}
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
const column = this.getEscapedColumnName(tableName, c.columnName);
let columnName = c.dataType !== c.paramType ? `cast(${column} as ${c.paramType})`
: `${column}`;
return `${columnName} AS ${c.paramName}`;
}).join(',\n');
}
private getPredictInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return this.getColumnName(tableName, c.paramName || '', c.columnName);
}).join(',\n');
}
private getColumnName(tableName: string, columnName: string, displayName: string) {
const column = this.getEscapedColumnName(tableName, columnName);
return columnName && columnName !== displayName ?
`${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column;
}
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.paramName ? `${this.getEscapedColumnName(tableName, c.paramName)}`
: `${this.getEscapedColumnName(tableName, c.columnName)}`;
}).join(',\n');
}
private getOutputParameters(columns: PredictColumn[]) {
return columns.map(c => {
return `${c.paramName} ${c.dataType}`;
}).join(',\n');
}
}