mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-01 17:23:35 -05:00
@@ -44,11 +44,28 @@ export const registeredModelsTableName = 'registeredModelsTableName';
|
||||
export const rPathConfigKey = 'rPath';
|
||||
export const adsPythonBundleVersion = '0.0.1';
|
||||
|
||||
// TSQL
|
||||
//
|
||||
|
||||
// The data types that are supported to convert model's parameters to SQL data
|
||||
export const supportedDataTypes = [
|
||||
'BIGINT',
|
||||
'INT',
|
||||
'SMALLINT',
|
||||
'REAL',
|
||||
'FLOAT',
|
||||
'VARCHAR(MAX)',
|
||||
'BIT'
|
||||
];
|
||||
export const varcharMax = 'VARCHAR(MAX)';
|
||||
export const varcharDefaultLength = 100;
|
||||
|
||||
// Localized texts
|
||||
//
|
||||
export const msgYes = localize('msgYes', "Yes");
|
||||
export const msgNo = localize('msgNo', "No");
|
||||
export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed.");
|
||||
export const notebookExtensionFailedError = localize('notebookExtensionFailedError', "The extension failed to load because of it's dependency to Notebook extension. Please check the output log for Notebook extension to get more details");
|
||||
export const verifyOdbcDriverError = localize('mls.verifyOdbcDriverError.error', "'{0}' is required for package management. Please make sure it is installed and set up correctly.", supportedODBCDriver);
|
||||
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
|
||||
export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); }
|
||||
|
||||
@@ -62,12 +62,18 @@ export default class MainController implements vscode.Disposable {
|
||||
* Returns an instance of Server Installation from notebook extension
|
||||
*/
|
||||
private async getNotebookExtensionApis(): Promise<nbExtensionApis.IExtensionApi> {
|
||||
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
|
||||
if (nbExtension) {
|
||||
await nbExtension.activate();
|
||||
return (nbExtension.exports as nbExtensionApis.IExtensionApi);
|
||||
} else {
|
||||
throw new Error(constants.notebookExtensionNotLoaded);
|
||||
try {
|
||||
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
|
||||
if (nbExtension) {
|
||||
await nbExtension.activate();
|
||||
return (nbExtension.exports as nbExtensionApis.IExtensionApi);
|
||||
} else {
|
||||
throw new Error(constants.notebookExtensionNotLoaded);
|
||||
}
|
||||
} catch (err) {
|
||||
this._outputChannel.appendLine(constants.notebookExtensionFailedError);
|
||||
this._apiWrapper.showErrorMessage(constants.notebookExtensionFailedError);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,11 +82,12 @@ export class ModelPythonClient {
|
||||
if value in type_map:
|
||||
p_type = type_map[value]
|
||||
name = type_list[value]
|
||||
parameters[paramType].append({
|
||||
'name': p.name,
|
||||
'type': p_type,
|
||||
'originalType': name
|
||||
})`,
|
||||
if name != 'undefined':
|
||||
parameters[paramType].append({
|
||||
'name': p.name,
|
||||
'type': p_type,
|
||||
'originalType': name
|
||||
})`,
|
||||
|
||||
'addParameters(onnx_model.graph.input, "inputs")',
|
||||
'addParameters(onnx_model.graph.output, "outputs")',
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
export interface TableColumn {
|
||||
columnName: string;
|
||||
dataType?: string;
|
||||
maxLength?: number;
|
||||
}
|
||||
|
||||
export interface PredictColumn extends TableColumn {
|
||||
|
||||
@@ -127,7 +127,8 @@ export class PredictService {
|
||||
result.rows.forEach(row => {
|
||||
list.push({
|
||||
columnName: row[0].displayValue,
|
||||
dataType: row[1].displayValue.toLocaleUpperCase()
|
||||
dataType: row[1].displayValue.toLocaleUpperCase(),
|
||||
maxLength: row[2].isNull ? undefined : +row[2].displayValue.toLocaleUpperCase()
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -5,10 +5,11 @@
|
||||
|
||||
import * as utils from '../common/utils';
|
||||
import { PredictColumn, DatabaseTable } from './interfaces';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
export function getTableColumnsScript(databaseTable: DatabaseTable): string {
|
||||
return `
|
||||
SELECT COLUMN_NAME,DATA_TYPE
|
||||
SELECT COLUMN_NAME,DATA_TYPE,CHARACTER_MAXIMUM_LENGTH
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
|
||||
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
|
||||
@@ -57,13 +58,13 @@ export function getPredictScriptWithModelBytes(
|
||||
modelBytes: string,
|
||||
columns: PredictColumn[],
|
||||
outputColumns: PredictColumn[],
|
||||
databaseNameTable: DatabaseTable): string {
|
||||
sourceTable: DatabaseTable): string {
|
||||
return `
|
||||
WITH predict_input
|
||||
AS (
|
||||
SELECT TOP 1000
|
||||
${getInputColumnNames(columns, 'pi')}
|
||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] AS pi
|
||||
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] AS pi
|
||||
)
|
||||
SELECT
|
||||
${getPredictColumnNames(columns, 'predict_input')},
|
||||
@@ -78,11 +79,14 @@ ${getOutputParameters(outputColumns)}
|
||||
export function getEscapedColumnName(tableName: string, columnName: string): string {
|
||||
return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`;
|
||||
}
|
||||
|
||||
export function getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
|
||||
return columns.map(c => {
|
||||
const column = getEscapedColumnName(tableName, c.columnName);
|
||||
let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${c.paramType})`
|
||||
const maxLength = c.maxLength !== undefined ? c.maxLength : constants.varcharDefaultLength;
|
||||
let paramType = c.paramType === constants.varcharMax ? `VARCHAR(${maxLength})` : c.paramType;
|
||||
let columnName = c.dataType !== c.paramType ? `CAST(${column} AS ${paramType})`
|
||||
: `${column}`;
|
||||
return `${columnName} AS ${c.paramName}`;
|
||||
}).join(',\n ');
|
||||
|
||||
@@ -6,12 +6,13 @@
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as Queries from '../../prediction/queries';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { ImportedModel } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn, PredictColumn } from '../../prediction/interfaces';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
@@ -114,11 +115,13 @@ describe('PredictService', () => {
|
||||
const expected: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'INT'
|
||||
dataType: 'INT',
|
||||
maxLength: undefined
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'VARCHAR'
|
||||
dataType: 'VARCHAR',
|
||||
maxLength: 10
|
||||
}
|
||||
];
|
||||
const table: DatabaseTable =
|
||||
@@ -141,6 +144,11 @@ describe('PredictService', () => {
|
||||
displayValue: 'int',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '',
|
||||
isNull: true,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
], [
|
||||
{
|
||||
@@ -152,6 +160,11 @@ describe('PredictService', () => {
|
||||
displayValue: 'varchar',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '10',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]]
|
||||
};
|
||||
@@ -175,12 +188,14 @@ describe('PredictService', () => {
|
||||
{
|
||||
paramName: 'p1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
columnName: '',
|
||||
maxLength: undefined
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
dataType: 'varchar',
|
||||
columnName: ''
|
||||
columnName: '',
|
||||
maxLength: 10
|
||||
}
|
||||
],
|
||||
outputColumns: [
|
||||
@@ -298,4 +313,35 @@ describe('PredictService', () => {
|
||||
should.notEqual(actual, undefined);
|
||||
should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
|
||||
});
|
||||
|
||||
it('getInputColumnNames should user column max length for varchar type', async function (): Promise<void> {
|
||||
const columns: PredictColumn[] = [
|
||||
{
|
||||
paramName: 'p1',
|
||||
paramType: 'VARCHAR(MAX)',
|
||||
columnName: 'c1',
|
||||
dataType: 'VARCHAR',
|
||||
maxLength: 20
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
paramType: 'VARCHAR(MAX)',
|
||||
columnName: 'c2',
|
||||
dataType: 'DATETIME',
|
||||
maxLength: undefined
|
||||
},
|
||||
{
|
||||
paramName: 'p3',
|
||||
paramType: 'INT',
|
||||
columnName: 'c2',
|
||||
dataType: 'INT',
|
||||
maxLength: undefined
|
||||
},
|
||||
];
|
||||
|
||||
const tableName = 'tbname';
|
||||
let actual = Queries.getInputColumnNames(columns, tableName);
|
||||
let expected =`CAST([tbname].[c1] AS VARCHAR(20)) AS p1,\n\tCAST([tbname].[c2] AS VARCHAR(100)) AS p2,\n\t[tbname].[c2] AS p3`;
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -180,7 +180,7 @@ describe('Predict Wizard', () => {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
|
||||
}
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data, undefined);
|
||||
should.notEqual(view.azureModelsComponent?.data, undefined, 'Data from Azure component should not be null');
|
||||
|
||||
if (view.modelBrowsePage) {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
|
||||
@@ -188,19 +188,21 @@ describe('Predict Wizard', () => {
|
||||
await view.refresh();
|
||||
testContext.onClick.fire(undefined);
|
||||
|
||||
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined);
|
||||
|
||||
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels, 'Model source should be registered models');
|
||||
should.notEqual(view.localModelsComponent?.data, undefined, 'Data from local model component should not be null');
|
||||
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined, 'Data from registered model component should not be null');
|
||||
if (view.modelBrowsePage?.registeredModelsComponent?.data) {
|
||||
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1);
|
||||
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1, 'Data from registered model component should not be empty');
|
||||
}
|
||||
|
||||
|
||||
should.notEqual(await view.getModelFileName(), undefined);
|
||||
should.notEqual(await view.getModelFileName(), undefined, 'Model file name should not be null');
|
||||
await view.columnsSelectionPage?.onEnter();
|
||||
await view.columnsSelectionPage?.inputColumnsComponent?.loadWithTable(tableNames[0]);
|
||||
|
||||
should.notEqual(view.columnsSelectionPage?.data, undefined);
|
||||
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name);
|
||||
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length);
|
||||
should.notEqual(view.columnsSelectionPage?.data, undefined, 'Data from column selection component should not be null');
|
||||
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, `unexpected number of inputs. ${view.columnsSelectionPage?.data?.inputColumns?.length}` );
|
||||
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length, `unexpected number of outputs. ${view.columnsSelectionPage?.data?.outputColumns?.length}`);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -19,15 +19,6 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
private _parameters: PredictColumn[] = [];
|
||||
private _loader: azdata.LoadingComponent;
|
||||
private _dataTypes: string[] = [
|
||||
'BIGINT',
|
||||
'INT',
|
||||
'SMALLINT',
|
||||
'REAL',
|
||||
'FLOAT',
|
||||
'VARCHAR(MAX)',
|
||||
'BIT'
|
||||
];
|
||||
|
||||
|
||||
/**
|
||||
@@ -171,7 +162,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
this._parameters = [];
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (this._table) {
|
||||
if (this._table && table && table.tableName !== constants.selectTableTitle) {
|
||||
if (this._forInput) {
|
||||
let columns: TableColumn[];
|
||||
try {
|
||||
@@ -196,8 +187,8 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
|
||||
if (this._table) {
|
||||
if (!this._forInput) {
|
||||
if (modelParameters?.outputs && this._dataTypes) {
|
||||
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
|
||||
if (modelParameters?.outputs && constants.supportedDataTypes) {
|
||||
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, constants.supportedDataTypes)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -299,7 +290,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
nameInput.value = column;
|
||||
|
||||
if (column) {
|
||||
this._parameters.push({ columnName: column.name, paramName: name, paramType: modelParameter.type });
|
||||
this._parameters.push({ columnName: column.name, paramName: name, paramType: modelParameter.type, maxLength: currentColumn?.maxLength });
|
||||
}
|
||||
const inputContainer = this._modelBuilder.flexContainer().withLayout({
|
||||
flexFlow: 'row',
|
||||
@@ -323,6 +314,10 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
||||
let selectedRow = this._parameters.find(x => x.paramName === name);
|
||||
if (selectedRow) {
|
||||
selectedRow.columnName = value || '';
|
||||
let tableColumn = columns.find(x => x.columnName === value);
|
||||
if (tableColumn) {
|
||||
selectedRow.maxLength = tableColumn.maxLength;
|
||||
}
|
||||
}
|
||||
|
||||
const currentColumn = columns.find(x => x.columnName === value);
|
||||
|
||||
@@ -132,7 +132,11 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
|
||||
}
|
||||
|
||||
private async onTableSelected(): Promise<void> {
|
||||
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
|
||||
await this.loadWithTable(this.databaseTable);
|
||||
}
|
||||
|
||||
public async loadWithTable(table: DatabaseTable): Promise<void> {
|
||||
await this._columns?.loadInputs(this._modelParameters, table);
|
||||
}
|
||||
|
||||
private get databaseTable(): DatabaseTable {
|
||||
|
||||
Reference in New Issue
Block a user