ML extension - Improving predict parameter mapping experience (#10264)

This commit is contained in:
Leila Lali
2020-05-10 18:10:17 -07:00
committed by GitHub
parent f6e7b56946
commit 3d2d791f18
44 changed files with 782 additions and 388 deletions

View File

@@ -39,7 +39,7 @@ export class ModelPythonClient {
/**
* Installs dependencies for python client
*/
private async installDependencies(): Promise<void> {
public async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
}
@@ -49,7 +49,6 @@ export class ModelPythonClient {
* @param modelPath Loads model parameters
*/
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
await this.installDependencies();
return await this.executeModelParametersScripts(modelPath);
}
@@ -61,6 +60,9 @@ export class ModelPythonClient {
'import json',
`onnx_model_path = '${modelFolderPath}'`,
`onnx_model = onnx.load_model(onnx_model_path)`,
`type_list = ['undefined',
'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double',
'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`,
`type_map = {
onnx.TensorProto.DataType.FLOAT: 'real',
onnx.TensorProto.DataType.UINT8: 'tinyint',
@@ -76,13 +78,14 @@ export class ModelPythonClient {
`def addParameters(list, paramType):
for id, p in enumerate(list):
p_type = ''
if p.type.tensor_type.elem_type in type_map:
p_type = type_map[p.type.tensor_type.elem_type]
value = p.type.tensor_type.elem_type
if value in type_map:
p_type = type_map[value]
name = type_list[value]
parameters[paramType].append({
'name': p.name,
'type': p_type
'type': p_type,
'originalType': name
})`,
'addParameters(onnx_model.graph.input, "inputs")',