ML extension - Improving predict parameter mapping experience (#10264)

This commit is contained in:
Leila Lali
2020-05-10 18:10:17 -07:00
committed by GitHub
parent f6e7b56946
commit 3d2d791f18
44 changed files with 782 additions and 388 deletions

View File

@@ -10,6 +10,7 @@ export interface TableColumn {
export interface PredictColumn extends TableColumn {
paramName?: string;
paramType?: string;
}
export interface DatabaseTable {

View File

@@ -35,6 +35,25 @@ export class PredictService {
return [];
}
/**
* Returns true if server supports ONNX
*/
public async serverSupportOnnxModel(): Promise<boolean> {
try {
let connection = await this.getCurrentConnection();
if (connection) {
const serverInfo = await this._apiWrapper.getServerInfo(connection.connectionId);
// Right now only Azure SQL Edge support Onnx
//
return serverInfo && serverInfo.engineEditionId === 9;
}
return false;
} catch (error) {
console.log(error);
return false;
}
}
/**
* Generates prediction script given model info and predict parameters
* @param predictParams predict parameters
@@ -157,7 +176,7 @@ AS (
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
)
SELECT
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getPredictInputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = @model, DATA = predict_input)
WITH (
${this.getOutputParameters(outputColumns)}
@@ -186,7 +205,20 @@ WITH (
`;
}
private getEscapedColumnName(tableName: string, columnName: string): string {
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
}
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
const column = this.getEscapedColumnName(tableName, c.columnName);
let columnName = c.dataType !== c.paramType ? `cast(${column} as ${c.paramType})`
: `${column}`;
return `${columnName} AS ${c.paramName}`;
}).join(',\n');
}
private getPredictInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return this.getColumnName(tableName, c.paramName || '', c.columnName);
}).join(',\n');
@@ -199,12 +231,15 @@ WITH (
}
private getColumnName(tableName: string, columnName: string, displayName: string) {
return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`;
const column = this.getEscapedColumnName(tableName, columnName);
return columnName && columnName !== displayName ?
`${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column;
}
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`;
return c.paramName ? `${this.getEscapedColumnName(tableName, c.paramName)}`
: `${this.getEscapedColumnName(tableName, c.columnName)}`;
}).join(',\n');
}