mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 18:46:40 -05:00
ML - Fixed script formatting for prediction (#11767)
* Fixed script formatting for prediction
This commit is contained in:
@@ -64,13 +64,13 @@ export class ModelPythonClient {
|
|||||||
'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double',
|
'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double',
|
||||||
'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`,
|
'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`,
|
||||||
`type_map = {
|
`type_map = {
|
||||||
onnx.TensorProto.DataType.FLOAT: 'real',
|
onnx.TensorProto.DataType.FLOAT: 'REAL',
|
||||||
onnx.TensorProto.DataType.UINT8: 'tinyint',
|
onnx.TensorProto.DataType.UINT8: 'TINYINT',
|
||||||
onnx.TensorProto.DataType.INT16: 'smallint',
|
onnx.TensorProto.DataType.INT16: 'SMALLINT',
|
||||||
onnx.TensorProto.DataType.INT32: 'int',
|
onnx.TensorProto.DataType.INT32: 'INT',
|
||||||
onnx.TensorProto.DataType.INT64: 'bigint',
|
onnx.TensorProto.DataType.INT64: 'BIGINT',
|
||||||
onnx.TensorProto.DataType.STRING: 'varchar(MAX)',
|
onnx.TensorProto.DataType.STRING: 'VARCHAR(MAX)',
|
||||||
onnx.TensorProto.DataType.DOUBLE: 'float'}`,
|
onnx.TensorProto.DataType.DOUBLE: 'FLOAT'}`,
|
||||||
`parameters = {
|
`parameters = {
|
||||||
"inputs": [],
|
"inputs": [],
|
||||||
"outputs": []
|
"outputs": []
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ import { ApiWrapper } from '../common/apiWrapper';
|
|||||||
import { QueryRunner } from '../common/queryRunner';
|
import { QueryRunner } from '../common/queryRunner';
|
||||||
import * as utils from '../common/utils';
|
import * as utils from '../common/utils';
|
||||||
import { ImportedModel } from '../modelManagement/interfaces';
|
import { ImportedModel } from '../modelManagement/interfaces';
|
||||||
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
|
import { PredictParameters, DatabaseTable, TableColumn } from '../prediction/interfaces';
|
||||||
|
import * as queries from './queries';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Service to make prediction
|
* Service to make prediction
|
||||||
@@ -67,7 +68,7 @@ export class PredictService {
|
|||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
let query = '';
|
let query = '';
|
||||||
if (registeredModel && registeredModel.id) {
|
if (registeredModel && registeredModel.id) {
|
||||||
query = this.getPredictScriptWithModelId(
|
query = queries.getPredictScriptWithModelId(
|
||||||
registeredModel.id,
|
registeredModel.id,
|
||||||
predictParams.inputColumns || [],
|
predictParams.inputColumns || [],
|
||||||
predictParams.outputColumns || [],
|
predictParams.outputColumns || [],
|
||||||
@@ -75,7 +76,7 @@ export class PredictService {
|
|||||||
registeredModel.table);
|
registeredModel.table);
|
||||||
} else if (filePath) {
|
} else if (filePath) {
|
||||||
let modelBytes = await utils.readFileInHex(filePath || '');
|
let modelBytes = await utils.readFileInHex(filePath || '');
|
||||||
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
|
query = queries.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
|
||||||
predictParams.outputColumns || [],
|
predictParams.outputColumns || [],
|
||||||
predictParams);
|
predictParams);
|
||||||
}
|
}
|
||||||
@@ -97,7 +98,7 @@ export class PredictService {
|
|||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
let list: DatabaseTable[] = [];
|
let list: DatabaseTable[] = [];
|
||||||
if (connection) {
|
if (connection) {
|
||||||
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName));
|
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, queries.getTablesScript(databaseName));
|
||||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||||
if (result && result.rows && result.rows.length > 0) {
|
if (result && result.rows && result.rows.length > 0) {
|
||||||
result.rows.forEach(row => {
|
result.rows.forEach(row => {
|
||||||
@@ -120,13 +121,13 @@ export class PredictService {
|
|||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
let list: TableColumn[] = [];
|
let list: TableColumn[] = [];
|
||||||
if (connection && databaseTable.databaseName) {
|
if (connection && databaseTable.databaseName) {
|
||||||
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
|
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, queries.getTableColumnsScript(databaseTable));
|
||||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||||
if (result && result.rows && result.rows.length > 0) {
|
if (result && result.rows && result.rows.length > 0) {
|
||||||
result.rows.forEach(row => {
|
result.rows.forEach(row => {
|
||||||
list.push({
|
list.push({
|
||||||
columnName: row[0].displayValue,
|
columnName: row[0].displayValue,
|
||||||
dataType: row[1].displayValue
|
dataType: row[1].displayValue.toLocaleUpperCase()
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -137,112 +138,5 @@ export class PredictService {
|
|||||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||||
return await this._apiWrapper.getCurrentConnection();
|
return await this._apiWrapper.getCurrentConnection();
|
||||||
}
|
}
|
||||||
|
|
||||||
private getTableColumnsScript(databaseTable: DatabaseTable): string {
|
|
||||||
return `
|
|
||||||
SELECT COLUMN_NAME,DATA_TYPE
|
|
||||||
FROM INFORMATION_SCHEMA.COLUMNS
|
|
||||||
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
|
|
||||||
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
|
|
||||||
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
|
|
||||||
`;
|
|
||||||
}
|
|
||||||
|
|
||||||
private getTablesScript(databaseName: string): string {
|
|
||||||
return `
|
|
||||||
SELECT TABLE_NAME,TABLE_SCHEMA
|
|
||||||
FROM INFORMATION_SCHEMA.TABLES
|
|
||||||
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
|
|
||||||
`;
|
|
||||||
}
|
|
||||||
|
|
||||||
private getPredictScriptWithModelId(
|
|
||||||
modelId: number,
|
|
||||||
columns: PredictColumn[],
|
|
||||||
outputColumns: PredictColumn[],
|
|
||||||
sourceTable: DatabaseTable,
|
|
||||||
importTable: DatabaseTable): string {
|
|
||||||
const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || '');
|
|
||||||
return `
|
|
||||||
DECLARE @model VARBINARY(max) = (
|
|
||||||
SELECT model
|
|
||||||
FROM ${threePartTableName}
|
|
||||||
WHERE model_id = ${modelId}
|
|
||||||
);
|
|
||||||
WITH predict_input
|
|
||||||
AS (
|
|
||||||
SELECT TOP 1000
|
|
||||||
${this.getInputColumnNames(columns, 'pi')}
|
|
||||||
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
${this.getPredictColumnNames(columns, 'predict_input')},
|
|
||||||
${this.getPredictInputColumnNames(outputColumns, 'p')}
|
|
||||||
FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx)
|
|
||||||
WITH (
|
|
||||||
${this.getOutputParameters(outputColumns)}
|
|
||||||
) AS p
|
|
||||||
`;
|
|
||||||
}
|
|
||||||
|
|
||||||
private getPredictScriptWithModelBytes(
|
|
||||||
modelBytes: string,
|
|
||||||
columns: PredictColumn[],
|
|
||||||
outputColumns: PredictColumn[],
|
|
||||||
databaseNameTable: DatabaseTable): string {
|
|
||||||
return `
|
|
||||||
WITH predict_input
|
|
||||||
AS (
|
|
||||||
SELECT TOP 1000
|
|
||||||
${this.getInputColumnNames(columns, 'pi')}
|
|
||||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
|
||||||
)
|
|
||||||
SELECT
|
|
||||||
${this.getPredictColumnNames(columns, 'predict_input')},
|
|
||||||
${this.getPredictInputColumnNames(outputColumns, 'p')}
|
|
||||||
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx)
|
|
||||||
WITH (
|
|
||||||
${this.getOutputParameters(outputColumns)}
|
|
||||||
) AS p
|
|
||||||
`;
|
|
||||||
}
|
|
||||||
|
|
||||||
private getEscapedColumnName(tableName: string, columnName: string): string {
|
|
||||||
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
|
|
||||||
}
|
|
||||||
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
|
||||||
|
|
||||||
return columns.map(c => {
|
|
||||||
const column = this.getEscapedColumnName(tableName, c.columnName);
|
|
||||||
let columnName = c.dataType !== c.paramType ? `cast(${column} as ${c.paramType})`
|
|
||||||
: `${column}`;
|
|
||||||
return `${columnName} AS ${c.paramName}`;
|
|
||||||
}).join(',\n');
|
|
||||||
}
|
|
||||||
|
|
||||||
private getPredictInputColumnNames(columns: PredictColumn[], tableName: string) {
|
|
||||||
return columns.map(c => {
|
|
||||||
return this.getColumnName(tableName, c.paramName || '', c.columnName);
|
|
||||||
}).join(',\n');
|
|
||||||
}
|
|
||||||
|
|
||||||
private getColumnName(tableName: string, columnName: string, displayName: string) {
|
|
||||||
const column = this.getEscapedColumnName(tableName, columnName);
|
|
||||||
return columnName && columnName !== displayName ?
|
|
||||||
`${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column;
|
|
||||||
}
|
|
||||||
|
|
||||||
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
|
|
||||||
return columns.map(c => {
|
|
||||||
return c.paramName ? `${this.getEscapedColumnName(tableName, c.paramName)}`
|
|
||||||
: `${this.getEscapedColumnName(tableName, c.columnName)}`;
|
|
||||||
}).join(',\n');
|
|
||||||
}
|
|
||||||
|
|
||||||
private getOutputParameters(columns: PredictColumn[]) {
|
|
||||||
return columns.map(c => {
|
|
||||||
return `${c.paramName} ${c.dataType}`;
|
|
||||||
}).join(',\n');
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
114
extensions/machine-learning/src/prediction/queries.ts
Normal file
114
extensions/machine-learning/src/prediction/queries.ts
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as utils from '../common/utils';
|
||||||
|
import { PredictColumn, DatabaseTable } from './interfaces';
|
||||||
|
|
||||||
|
export function getTableColumnsScript(databaseTable: DatabaseTable): string {
|
||||||
|
return `
|
||||||
|
SELECT COLUMN_NAME,DATA_TYPE
|
||||||
|
FROM INFORMATION_SCHEMA.COLUMNS
|
||||||
|
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
|
||||||
|
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
|
||||||
|
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getTablesScript(databaseName: string): string {
|
||||||
|
return `
|
||||||
|
SELECT TABLE_NAME,TABLE_SCHEMA
|
||||||
|
FROM INFORMATION_SCHEMA.TABLES
|
||||||
|
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getPredictScriptWithModelId(
|
||||||
|
modelId: number,
|
||||||
|
columns: PredictColumn[],
|
||||||
|
outputColumns: PredictColumn[],
|
||||||
|
sourceTable: DatabaseTable,
|
||||||
|
importTable: DatabaseTable): string {
|
||||||
|
const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || '');
|
||||||
|
return `
|
||||||
|
DECLARE @model VARBINARY(max) = (
|
||||||
|
SELECT model
|
||||||
|
FROM ${threePartTableName}
|
||||||
|
WHERE model_id = ${modelId}
|
||||||
|
);
|
||||||
|
WITH predict_input
|
||||||
|
AS (
|
||||||
|
SELECT TOP 1000
|
||||||
|
${getInputColumnNames(columns, 'pi')}
|
||||||
|
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] AS pi
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
${getPredictColumnNames(columns, 'predict_input')},
|
||||||
|
${getPredictInputColumnNames(outputColumns, 'p')}
|
||||||
|
FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx)
|
||||||
|
WITH (
|
||||||
|
${getOutputParameters(outputColumns)}
|
||||||
|
) AS p
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getPredictScriptWithModelBytes(
|
||||||
|
modelBytes: string,
|
||||||
|
columns: PredictColumn[],
|
||||||
|
outputColumns: PredictColumn[],
|
||||||
|
databaseNameTable: DatabaseTable): string {
|
||||||
|
return `
|
||||||
|
WITH predict_input
|
||||||
|
AS (
|
||||||
|
SELECT TOP 1000
|
||||||
|
${getInputColumnNames(columns, 'pi')}
|
||||||
|
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] AS pi
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
${getPredictColumnNames(columns, 'predict_input')},
|
||||||
|
${getPredictInputColumnNames(outputColumns, 'p')}
|
||||||
|
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx)
|
||||||
|
WITH (
|
||||||
|
${getOutputParameters(outputColumns)}
|
||||||
|
) AS p
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getEscapedColumnName(tableName: string, columnName: string): string {
|
||||||
|
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
|
||||||
|
}
|
||||||
|
export function getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||||
|
|
||||||
|
return columns.map(c => {
|
||||||
|
const column = getEscapedColumnName(tableName, c.columnName);
|
||||||
|
let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${c.paramType})`
|
||||||
|
: `${column}`;
|
||||||
|
return `${columnName} AS ${c.paramName}`;
|
||||||
|
}).join(',\n ');
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getPredictInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||||
|
return columns.map(c => {
|
||||||
|
return getColumnName(tableName, c.paramName || '', c.columnName);
|
||||||
|
}).join(',\n ');
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getColumnName(tableName: string, columnName: string, displayName: string) {
|
||||||
|
const column = getEscapedColumnName(tableName, columnName);
|
||||||
|
return columnName && columnName !== displayName ?
|
||||||
|
`${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getPredictColumnNames(columns: PredictColumn[], tableName: string) {
|
||||||
|
return columns.map(c => {
|
||||||
|
return c.paramName ? `${getEscapedColumnName(tableName, c.paramName)}`
|
||||||
|
: `${getEscapedColumnName(tableName, c.columnName)}`;
|
||||||
|
}).join(',\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getOutputParameters(columns: PredictColumn[]) {
|
||||||
|
return columns.map(c => {
|
||||||
|
return `${c.paramName} ${c.dataType}`;
|
||||||
|
}).join(',\n');
|
||||||
|
}
|
||||||
@@ -114,11 +114,11 @@ describe('PredictService', () => {
|
|||||||
const expected: TableColumn[] = [
|
const expected: TableColumn[] = [
|
||||||
{
|
{
|
||||||
columnName: 'c1',
|
columnName: 'c1',
|
||||||
dataType: 'int'
|
dataType: 'INT'
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
columnName: 'c2',
|
columnName: 'c2',
|
||||||
dataType: 'varchar'
|
dataType: 'VARCHAR'
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
const table: DatabaseTable =
|
const table: DatabaseTable =
|
||||||
|
|||||||
@@ -20,13 +20,13 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
|||||||
private _parameters: PredictColumn[] = [];
|
private _parameters: PredictColumn[] = [];
|
||||||
private _loader: azdata.LoadingComponent;
|
private _loader: azdata.LoadingComponent;
|
||||||
private _dataTypes: string[] = [
|
private _dataTypes: string[] = [
|
||||||
'bigint',
|
'BIGINT',
|
||||||
'int',
|
'INT',
|
||||||
'smallint',
|
'SMALLINT',
|
||||||
'real',
|
'REAL',
|
||||||
'float',
|
'FLOAT',
|
||||||
'varchar(MAX)',
|
'VARCHAR(MAX)',
|
||||||
'bit'
|
'BIT'
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user