Machine Learning Services - Model detection in predict wizard (#9609)

* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
Leila Lali
2020-03-25 13:18:19 -07:00
committed by GitHub
parent 176edde2aa
commit ab82c04766
44 changed files with 2265 additions and 376 deletions

View File

@@ -3,10 +3,13 @@
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
export interface PredictColumn {
name: string;
export interface TableColumn {
columnName: string;
dataType?: string;
displayName?: string;
}
export interface PredictColumn extends TableColumn {
paramName?: string;
}
export interface DatabaseTable {

View File

@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import * as utils from '../common/utils';
import { RegisteredModel } from '../modelManagement/interfaces';
import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces';
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
import { Config } from '../configurations/config';
/**
@@ -98,15 +98,18 @@ export class PredictService {
*Returns list of column names of a database
* @param databaseTable table info
*/
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<string[]> {
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<TableColumn[]> {
let connection = await this.getCurrentConnection();
let list: string[] = [];
let list: TableColumn[] = [];
if (connection && databaseTable.databaseName) {
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push(row[0].displayValue);
list.push({
columnName: row[0].displayValue,
dataType: row[1].displayValue
});
});
}
}
@@ -119,7 +122,7 @@ export class PredictService {
private getTableColumnsScript(databaseTable: DatabaseTable): string {
return `
SELECT COLUMN_NAME,*
SELECT COLUMN_NAME,DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
@@ -149,14 +152,14 @@ DECLARE @model VARBINARY(max) = (
WITH predict_input
AS (
SELECT TOP 1000
${this.getColumnNames(columns, 'pi')}
${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = @model, DATA = predict_input)
WITH (
${this.getColumnTypes(outputColumns)}
${this.getOutputParameters(outputColumns)}
) AS p
`;
}
@@ -170,33 +173,43 @@ WITH (
WITH predict_input
AS (
SELECT TOP 1000
${this.getColumnNames(columns, 'pi')}
${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
WITH (
${this.getColumnTypes(outputColumns)}
${this.getOutputParameters(outputColumns)}
) AS p
`;
}
private getColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`;
}).join(',\n');
}
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`;
return this.getColumnName(tableName, c.paramName || '', c.columnName);
}).join(',\n');
}
private getColumnTypes(columns: PredictColumn[]) {
private getOutputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return `${c.name} ${c.dataType}`;
return this.getColumnName(tableName, c.columnName, c.paramName || '');
}).join(',\n');
}
private getColumnName(tableName: string, columnName: string, displayName: string) {
return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`;
}
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`;
}).join(',\n');
}
private getOutputParameters(columns: PredictColumn[]) {
return columns.map(c => {
return `${c.paramName} ${c.dataType}`;
}).join(',\n');
}
}