mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 18:46:40 -05:00
ML extension - Improving predict parameter mapping experience (#10264)
This commit is contained in:
@@ -61,12 +61,19 @@ export class DeployedModelService {
|
||||
*/
|
||||
public async downloadModel(model: ImportedModel): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let fileContent: string = '';
|
||||
if (connection) {
|
||||
const query = queries.getModelContentQuery(model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const content = result.rows[0][0].displayValue;
|
||||
return await utils.writeFileFromHex(content);
|
||||
for (let index = 0; index < result.rows[0].length; index++) {
|
||||
const column = result.rows[0][index];
|
||||
let content = column.displayValue;
|
||||
content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content;
|
||||
fileContent = fileContent + content;
|
||||
}
|
||||
|
||||
return await utils.writeFileFromHex(fileContent);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToSelectError);
|
||||
}
|
||||
@@ -170,6 +177,13 @@ export class DeployedModelService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs the dependencies required for model management
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
await this._modelClient.installDependencies();
|
||||
}
|
||||
|
||||
public async getRecentImportTable(): Promise<DatabaseTable> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let table: DatabaseTable | undefined;
|
||||
@@ -209,6 +223,7 @@ export class DeployedModelService {
|
||||
deploymentTime: row[7].displayValue,
|
||||
deployedBy: row[8].displayValue,
|
||||
runId: row[9].displayValue,
|
||||
contentLength: +row[10].displayValue,
|
||||
table: table
|
||||
};
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ export interface ListWorkspaceModelsResult extends Array<WorkspaceModel> {
|
||||
*/
|
||||
export interface WorkspaceModel extends Resource {
|
||||
framework?: string;
|
||||
description?: string;
|
||||
frameworkVersion?: string;
|
||||
createdBy?: string;
|
||||
createdTime?: string;
|
||||
@@ -52,12 +53,14 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
||||
export interface ImportedModel extends ImportedModelDetails {
|
||||
id: number;
|
||||
content?: string;
|
||||
contentLength?: number;
|
||||
table: DatabaseTable;
|
||||
}
|
||||
|
||||
export interface ModelParameter {
|
||||
name: string;
|
||||
type: string;
|
||||
originalType?: string;
|
||||
}
|
||||
|
||||
export interface ModelParameters {
|
||||
|
||||
@@ -21,7 +21,12 @@ export class ModelConfigRecent {
|
||||
}
|
||||
|
||||
public storeModelTable(connection: azdata.connection.ConnectionProfile, table: DatabaseTable): void {
|
||||
this._memento.update(this.getKey(connection), table);
|
||||
if (connection && table?.databaseName && table?.tableName && table?.schema) {
|
||||
const current = this.getModelTable(connection);
|
||||
if (!current || current.databaseName !== table.databaseName || current.tableName !== table.tableName || current.schema !== table.schema) {
|
||||
this._memento.update(this.getKey(connection), table);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private getKey(connection: azdata.connection.ConnectionProfile): string {
|
||||
|
||||
@@ -39,7 +39,7 @@ export class ModelPythonClient {
|
||||
/**
|
||||
* Installs dependencies for python client
|
||||
*/
|
||||
private async installDependencies(): Promise<void> {
|
||||
public async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [
|
||||
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
|
||||
}
|
||||
@@ -49,7 +49,6 @@ export class ModelPythonClient {
|
||||
* @param modelPath Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
|
||||
await this.installDependencies();
|
||||
return await this.executeModelParametersScripts(modelPath);
|
||||
}
|
||||
|
||||
@@ -61,6 +60,9 @@ export class ModelPythonClient {
|
||||
'import json',
|
||||
`onnx_model_path = '${modelFolderPath}'`,
|
||||
`onnx_model = onnx.load_model(onnx_model_path)`,
|
||||
`type_list = ['undefined',
|
||||
'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double',
|
||||
'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`,
|
||||
`type_map = {
|
||||
onnx.TensorProto.DataType.FLOAT: 'real',
|
||||
onnx.TensorProto.DataType.UINT8: 'tinyint',
|
||||
@@ -76,13 +78,14 @@ export class ModelPythonClient {
|
||||
`def addParameters(list, paramType):
|
||||
for id, p in enumerate(list):
|
||||
p_type = ''
|
||||
|
||||
if p.type.tensor_type.elem_type in type_map:
|
||||
p_type = type_map[p.type.tensor_type.elem_type]
|
||||
|
||||
value = p.type.tensor_type.elem_type
|
||||
if value in type_map:
|
||||
p_type = type_map[value]
|
||||
name = type_list[value]
|
||||
parameters[paramType].append({
|
||||
'name': p.name,
|
||||
'type': p_type
|
||||
'type': p_type,
|
||||
'originalType': name
|
||||
})`,
|
||||
|
||||
'addParameters(onnx_model.graph.input, "inputs")',
|
||||
|
||||
@@ -144,12 +144,37 @@ export function getInsertModelQuery(model: ImportedModel, table: DatabaseTable):
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the query for loading model content from database
|
||||
* @param model model information
|
||||
*/
|
||||
export function getModelContentQuery(model: ImportedModel): string {
|
||||
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
|
||||
const len = model.contentLength !== undefined ? model.contentLength : 0;
|
||||
const maxLength = 1000;
|
||||
let numberOfColumns = len / maxLength;
|
||||
// The query provider doesn't return the whole file bites if too big. so loading the bites it blocks
|
||||
// and merge together to load the file
|
||||
numberOfColumns = numberOfColumns <= 0 ? 1 : numberOfColumns;
|
||||
let columns: string[] = [];
|
||||
let fileIndex = 0;
|
||||
for (let index = 0; index < numberOfColumns; index++) {
|
||||
const length = fileIndex === 0 ? maxLength + 1 : maxLength;
|
||||
columns.push(`substring(@str, ${fileIndex}, ${length}) as d${index}`);
|
||||
fileIndex = fileIndex + length;
|
||||
}
|
||||
|
||||
if (fileIndex < len) {
|
||||
columns.push(`substring(@str, ${fileIndex}, ${maxLength}) as d${columns.length}`);
|
||||
}
|
||||
return `
|
||||
SELECT model
|
||||
DECLARE @str varbinary(max)
|
||||
|
||||
SELECT @str=model
|
||||
FROM ${threePartTableName}
|
||||
WHERE model_id = ${model.id};
|
||||
|
||||
select ${columns.join(',')}
|
||||
`;
|
||||
}
|
||||
|
||||
@@ -190,6 +215,6 @@ export function getDeleteModelQuery(model: ImportedModel): string {
|
||||
`;
|
||||
}
|
||||
|
||||
export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id';
|
||||
export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id, len(model)';
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user