mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
ML extension - Improving predict parameter mapping experience (#10264)
This commit is contained in:
@@ -10,6 +10,7 @@ export interface TableColumn {
|
||||
|
||||
export interface PredictColumn extends TableColumn {
|
||||
paramName?: string;
|
||||
paramType?: string;
|
||||
}
|
||||
|
||||
export interface DatabaseTable {
|
||||
|
||||
@@ -35,6 +35,25 @@ export class PredictService {
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if server supports ONNX
|
||||
*/
|
||||
public async serverSupportOnnxModel(): Promise<boolean> {
|
||||
try {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
const serverInfo = await this._apiWrapper.getServerInfo(connection.connectionId);
|
||||
// Right now only Azure SQL Edge support Onnx
|
||||
//
|
||||
return serverInfo && serverInfo.engineEditionId === 9;
|
||||
}
|
||||
return false;
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates prediction script given model info and predict parameters
|
||||
* @param predictParams predict parameters
|
||||
@@ -157,7 +176,7 @@ AS (
|
||||
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
|
||||
)
|
||||
SELECT
|
||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
|
||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getPredictInputColumnNames(outputColumns, 'p')}
|
||||
FROM PREDICT(MODEL = @model, DATA = predict_input)
|
||||
WITH (
|
||||
${this.getOutputParameters(outputColumns)}
|
||||
@@ -186,7 +205,20 @@ WITH (
|
||||
`;
|
||||
}
|
||||
|
||||
private getEscapedColumnName(tableName: string, columnName: string): string {
|
||||
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
|
||||
}
|
||||
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
|
||||
return columns.map(c => {
|
||||
const column = this.getEscapedColumnName(tableName, c.columnName);
|
||||
let columnName = c.dataType !== c.paramType ? `cast(${column} as ${c.paramType})`
|
||||
: `${column}`;
|
||||
return `${columnName} AS ${c.paramName}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getPredictInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return this.getColumnName(tableName, c.paramName || '', c.columnName);
|
||||
}).join(',\n');
|
||||
@@ -199,12 +231,15 @@ WITH (
|
||||
}
|
||||
|
||||
private getColumnName(tableName: string, columnName: string, displayName: string) {
|
||||
return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`;
|
||||
const column = this.getEscapedColumnName(tableName, columnName);
|
||||
return columnName && columnName !== displayName ?
|
||||
`${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column;
|
||||
}
|
||||
|
||||
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`;
|
||||
return c.paramName ? `${this.getEscapedColumnName(tableName, c.paramName)}`
|
||||
: `${this.getEscapedColumnName(tableName, c.columnName)}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user