ML extension - Improving predict parameter mapping experience (#10264)

This commit is contained in:
Leila Lali
2020-05-10 18:10:17 -07:00
committed by GitHub
parent f6e7b56946
commit 3d2d791f18
44 changed files with 782 additions and 388 deletions

View File

@@ -61,12 +61,19 @@ export class DeployedModelService {
*/
public async downloadModel(model: ImportedModel): Promise<string> {
let connection = await this.getCurrentConnection();
let fileContent: string = '';
if (connection) {
const query = queries.getModelContentQuery(model);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
const content = result.rows[0][0].displayValue;
return await utils.writeFileFromHex(content);
for (let index = 0; index < result.rows[0].length; index++) {
const column = result.rows[0][index];
let content = column.displayValue;
content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content;
fileContent = fileContent + content;
}
return await utils.writeFileFromHex(fileContent);
} else {
throw Error(constants.invalidModelToSelectError);
}
@@ -170,6 +177,13 @@ export class DeployedModelService {
}
}
/**
* Installs the dependencies required for model management
*/
public async installDependencies(): Promise<void> {
await this._modelClient.installDependencies();
}
public async getRecentImportTable(): Promise<DatabaseTable> {
let connection = await this.getCurrentConnection();
let table: DatabaseTable | undefined;
@@ -209,6 +223,7 @@ export class DeployedModelService {
deploymentTime: row[7].displayValue,
deployedBy: row[8].displayValue,
runId: row[9].displayValue,
contentLength: +row[10].displayValue,
table: table
};
}

View File

@@ -18,6 +18,7 @@ export interface ListWorkspaceModelsResult extends Array<WorkspaceModel> {
*/
export interface WorkspaceModel extends Resource {
framework?: string;
description?: string;
frameworkVersion?: string;
createdBy?: string;
createdTime?: string;
@@ -52,12 +53,14 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
export interface ImportedModel extends ImportedModelDetails {
id: number;
content?: string;
contentLength?: number;
table: DatabaseTable;
}
export interface ModelParameter {
name: string;
type: string;
originalType?: string;
}
export interface ModelParameters {

View File

@@ -21,7 +21,12 @@ export class ModelConfigRecent {
}
public storeModelTable(connection: azdata.connection.ConnectionProfile, table: DatabaseTable): void {
this._memento.update(this.getKey(connection), table);
if (connection && table?.databaseName && table?.tableName && table?.schema) {
const current = this.getModelTable(connection);
if (!current || current.databaseName !== table.databaseName || current.tableName !== table.tableName || current.schema !== table.schema) {
this._memento.update(this.getKey(connection), table);
}
}
}
private getKey(connection: azdata.connection.ConnectionProfile): string {

View File

@@ -39,7 +39,7 @@ export class ModelPythonClient {
/**
* Installs dependencies for python client
*/
private async installDependencies(): Promise<void> {
public async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
}
@@ -49,7 +49,6 @@ export class ModelPythonClient {
* @param modelPath Loads model parameters
*/
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
await this.installDependencies();
return await this.executeModelParametersScripts(modelPath);
}
@@ -61,6 +60,9 @@ export class ModelPythonClient {
'import json',
`onnx_model_path = '${modelFolderPath}'`,
`onnx_model = onnx.load_model(onnx_model_path)`,
`type_list = ['undefined',
'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double',
'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`,
`type_map = {
onnx.TensorProto.DataType.FLOAT: 'real',
onnx.TensorProto.DataType.UINT8: 'tinyint',
@@ -76,13 +78,14 @@ export class ModelPythonClient {
`def addParameters(list, paramType):
for id, p in enumerate(list):
p_type = ''
if p.type.tensor_type.elem_type in type_map:
p_type = type_map[p.type.tensor_type.elem_type]
value = p.type.tensor_type.elem_type
if value in type_map:
p_type = type_map[value]
name = type_list[value]
parameters[paramType].append({
'name': p.name,
'type': p_type
'type': p_type,
'originalType': name
})`,
'addParameters(onnx_model.graph.input, "inputs")',

View File

@@ -144,12 +144,37 @@ export function getInsertModelQuery(model: ImportedModel, table: DatabaseTable):
`;
}
/**
* Returns the query for loading model content from database
* @param model model information
*/
export function getModelContentQuery(model: ImportedModel): string {
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
const len = model.contentLength !== undefined ? model.contentLength : 0;
const maxLength = 1000;
let numberOfColumns = len / maxLength;
// The query provider doesn't return the whole file bites if too big. so loading the bites it blocks
// and merge together to load the file
numberOfColumns = numberOfColumns <= 0 ? 1 : numberOfColumns;
let columns: string[] = [];
let fileIndex = 0;
for (let index = 0; index < numberOfColumns; index++) {
const length = fileIndex === 0 ? maxLength + 1 : maxLength;
columns.push(`substring(@str, ${fileIndex}, ${length}) as d${index}`);
fileIndex = fileIndex + length;
}
if (fileIndex < len) {
columns.push(`substring(@str, ${fileIndex}, ${maxLength}) as d${columns.length}`);
}
return `
SELECT model
DECLARE @str varbinary(max)
SELECT @str=model
FROM ${threePartTableName}
WHERE model_id = ${model.id};
select ${columns.join(',')}
`;
}
@@ -190,6 +215,6 @@ export function getDeleteModelQuery(model: ImportedModel): string {
`;
}
export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id';
export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id, len(model)';