mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-01-20 09:35:38 -05:00
ML extension - Improving predict parameter mapping experience (#10264)
This commit is contained in:
@@ -39,7 +39,7 @@ export class ModelPythonClient {
|
||||
/**
|
||||
* Installs dependencies for python client
|
||||
*/
|
||||
private async installDependencies(): Promise<void> {
|
||||
public async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [
|
||||
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
|
||||
}
|
||||
@@ -49,7 +49,6 @@ export class ModelPythonClient {
|
||||
* @param modelPath Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
|
||||
await this.installDependencies();
|
||||
return await this.executeModelParametersScripts(modelPath);
|
||||
}
|
||||
|
||||
@@ -61,6 +60,9 @@ export class ModelPythonClient {
|
||||
'import json',
|
||||
`onnx_model_path = '${modelFolderPath}'`,
|
||||
`onnx_model = onnx.load_model(onnx_model_path)`,
|
||||
`type_list = ['undefined',
|
||||
'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double',
|
||||
'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`,
|
||||
`type_map = {
|
||||
onnx.TensorProto.DataType.FLOAT: 'real',
|
||||
onnx.TensorProto.DataType.UINT8: 'tinyint',
|
||||
@@ -76,13 +78,14 @@ export class ModelPythonClient {
|
||||
`def addParameters(list, paramType):
|
||||
for id, p in enumerate(list):
|
||||
p_type = ''
|
||||
|
||||
if p.type.tensor_type.elem_type in type_map:
|
||||
p_type = type_map[p.type.tensor_type.elem_type]
|
||||
|
||||
value = p.type.tensor_type.elem_type
|
||||
if value in type_map:
|
||||
p_type = type_map[value]
|
||||
name = type_list[value]
|
||||
parameters[paramType].append({
|
||||
'name': p.name,
|
||||
'type': p_type
|
||||
'type': p_type,
|
||||
'originalType': name
|
||||
})`,
|
||||
|
||||
'addParameters(onnx_model.graph.input, "inputs")',
|
||||
|
||||
Reference in New Issue
Block a user