mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-07 17:23:56 -05:00
Machine Learning Services - Model detection in predict wizard (#9609)
* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
@@ -28,10 +28,16 @@ import * as utils from '../common/utils';
|
||||
*/
|
||||
export class AzureModelRegistryService {
|
||||
|
||||
private _amlClient: AzureMachineLearningWorkspaces | undefined;
|
||||
private _modelClient: WorkspaceModels | undefined;
|
||||
/**
|
||||
*
|
||||
* Creates new service
|
||||
*/
|
||||
constructor(private _apiWrapper: ApiWrapper, private _config: Config, private _httpClient: HttpClient, private _outputChannel: vscode.OutputChannel) {
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _config: Config,
|
||||
private _httpClient: HttpClient,
|
||||
private _outputChannel: vscode.OutputChannel) {
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -120,10 +126,18 @@ export class AzureModelRegistryService {
|
||||
return downloadedFilePath;
|
||||
}
|
||||
|
||||
public set AzureMachineLearningClient(value: AzureMachineLearningWorkspaces) {
|
||||
this._amlClient = value;
|
||||
}
|
||||
|
||||
public set ModelClient(value: WorkspaceModels) {
|
||||
this._modelClient = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
* Execute the background task to download the artifact
|
||||
*/
|
||||
public async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
|
||||
private async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
|
||||
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
|
||||
return results && results.length > 0 ? results[0] : constants.noResultError;
|
||||
}
|
||||
@@ -139,15 +153,14 @@ export class AzureModelRegistryService {
|
||||
|
||||
try {
|
||||
for (const tenant of account.properties.tenants) {
|
||||
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
|
||||
const token = tokens[tenant.id].token;
|
||||
const tokenType = tokens[tenant.id].tokenType;
|
||||
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id);
|
||||
const client = await this.getAmlClient(account, subscription, tenant);
|
||||
let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription();
|
||||
resources.push(...result);
|
||||
if (result) {
|
||||
resources.push(...result);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
console.log(error);
|
||||
}
|
||||
return resources;
|
||||
}
|
||||
@@ -161,9 +174,11 @@ export class AzureModelRegistryService {
|
||||
|
||||
for (const tenant of account.properties.tenants) {
|
||||
try {
|
||||
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
|
||||
const client = await this.getClient(baseUri, account, subscription, tenant);
|
||||
let modelsClient = new WorkspaceModels(client);
|
||||
let options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
|
||||
};
|
||||
const client = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
let modelsClient = this.getModelClient(client);
|
||||
resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || ''));
|
||||
|
||||
} catch (error) {
|
||||
@@ -182,22 +197,28 @@ export class AzureModelRegistryService {
|
||||
client: AzureMachineLearningWorkspaces): Promise<Asset> {
|
||||
|
||||
const modelId = this.getModelId(model);
|
||||
let modelsClient = new Assets(client);
|
||||
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
|
||||
if (modelId) {
|
||||
let modelsClient = new Assets(client);
|
||||
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
|
||||
} else {
|
||||
throw Error(constants.invalidModelIdError(model.url));
|
||||
}
|
||||
}
|
||||
|
||||
public async getAssetArtifactsDownloadLinks(
|
||||
private async getAssetArtifactsDownloadLinks(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource,
|
||||
workspace: Workspace,
|
||||
model: WorkspaceModel,
|
||||
tenant: any): Promise<string[]> {
|
||||
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
|
||||
const modelManagementClient = await this.getClient(baseUri, account, subscription, tenant);
|
||||
let options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
|
||||
};
|
||||
const modelManagementClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient);
|
||||
baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
|
||||
const experienceClient = await this.getClient(baseUri, account, subscription, tenant);
|
||||
options.baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
|
||||
const experienceClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
const artifactClient = new Artifacts(experienceClient);
|
||||
let downloadLinks: string[] = [];
|
||||
if (asset && asset.artifacts) {
|
||||
@@ -230,17 +251,19 @@ export class AzureModelRegistryService {
|
||||
downloadLinkPromises.push(promise);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
downloadLinks = await Promise.all(downloadLinkPromises);
|
||||
} catch (rejectedPromiseError) {
|
||||
return rejectedPromiseError;
|
||||
}
|
||||
return downloadLinks;
|
||||
|
||||
} else {
|
||||
throw Error(constants.noArtifactError(model.url));
|
||||
}
|
||||
return downloadLinks;
|
||||
}
|
||||
|
||||
public getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
|
||||
private getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
|
||||
const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/;
|
||||
if (idOrPrefix) {
|
||||
const parts = artifactRegex.exec(idOrPrefix);
|
||||
@@ -263,16 +286,35 @@ export class AzureModelRegistryService {
|
||||
return baseUri;
|
||||
}
|
||||
|
||||
private async getClient(baseUri: string, account: azdata.Account, subscription: azureResource.AzureResourceSubscription, tenant: any): Promise<AzureMachineLearningWorkspaces> {
|
||||
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
|
||||
const token = tokens[tenant.id].token;
|
||||
const tokenType = tokens[tenant.id].tokenType;
|
||||
const options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: baseUri
|
||||
};
|
||||
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
|
||||
client.apiVersion = this._config.amlApiVersion;
|
||||
return client;
|
||||
private getModelClient(amlClient: AzureMachineLearningWorkspaces) {
|
||||
return this._modelClient ?? new WorkspaceModels(amlClient);
|
||||
}
|
||||
|
||||
private async getAmlClient(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
tenant: any,
|
||||
options: AzureMachineLearningWorkspacesOptions | undefined = undefined,
|
||||
apiVersion: string | undefined = undefined): Promise<AzureMachineLearningWorkspaces> {
|
||||
if (this._amlClient) {
|
||||
return this._amlClient;
|
||||
} else {
|
||||
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
|
||||
let token: string = '';
|
||||
let tokenType: string | undefined = undefined;
|
||||
if (tokens && tenant.id in tokens) {
|
||||
const tokenForId = tokens[tenant.id];
|
||||
if (tokenForId) {
|
||||
token = tokenForId.token;
|
||||
tokenType = tokenForId.tokenType;
|
||||
}
|
||||
}
|
||||
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
|
||||
if (apiVersion) {
|
||||
client.apiVersion = apiVersion;
|
||||
}
|
||||
return client;
|
||||
}
|
||||
}
|
||||
|
||||
private getModelId(model: WorkspaceModel): string {
|
||||
|
||||
@@ -9,73 +9,85 @@ import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import { Config } from '../configurations/config';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
|
||||
import { ModelImporter } from './modelImporter';
|
||||
import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces';
|
||||
import { ModelPythonClient } from './modelPythonClient';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
/**
|
||||
* Service to registered models
|
||||
* Service to deployed models
|
||||
*/
|
||||
export class RegisteredModelService {
|
||||
export class DeployedModelService {
|
||||
|
||||
/**
|
||||
*
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _config: Config,
|
||||
private _queryRunner: QueryRunner,
|
||||
private _modelImporter: ModelImporter) {
|
||||
private _modelClient: ModelPythonClient) {
|
||||
}
|
||||
|
||||
public async getRegisteredModels(): Promise<RegisteredModel[]> {
|
||||
/**
|
||||
* Returns deployed models
|
||||
*/
|
||||
public async getDeployedModels(): Promise<RegisteredModel[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: RegisteredModel[] = [];
|
||||
if (connection) {
|
||||
let query = this.getConfigureQuery(connection.databaseName);
|
||||
await this._queryRunner.safeRunQuery(connection, query);
|
||||
query = this.registeredModelsQuery();
|
||||
query = this.getDeployedModelsQuery();
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push(this.loadModelData(row));
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.noConnectionError);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||
return {
|
||||
id: +row[0].displayValue,
|
||||
artifactName: row[1].displayValue,
|
||||
title: row[2].displayValue,
|
||||
description: row[3].displayValue,
|
||||
version: row[4].displayValue,
|
||||
created: row[5].displayValue
|
||||
};
|
||||
}
|
||||
|
||||
public async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
||||
/**
|
||||
* Downloads model
|
||||
* @param model model object
|
||||
*/
|
||||
public async downloadModel(model: RegisteredModel): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
const query = this.getUpdateModelScript(connection.databaseName, model);
|
||||
const query = this.getModelContentQuery(model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
const content = result.rows[0][0].displayValue;
|
||||
return await utils.writeFileFromHex(content);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToSelectError);
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.noConnectionError);
|
||||
}
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
||||
/**
|
||||
* Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(filePath: string): Promise<ModelParameters> {
|
||||
return await this._modelClient.loadModelParameters(filePath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deploys local model
|
||||
* @param filePath model file path
|
||||
* @param details model details
|
||||
*/
|
||||
public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let currentModels = await this.getRegisteredModels();
|
||||
await this._modelImporter.registerModel(connection, filePath);
|
||||
let updatedModels = await this.getRegisteredModels();
|
||||
let currentModels = await this.getDeployedModels();
|
||||
await this._modelClient.deployModel(connection, filePath);
|
||||
let updatedModels = await this.getDeployedModels();
|
||||
if (details && updatedModels.length >= currentModels.length + 1) {
|
||||
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
|
||||
const addedModel = updatedModels[updatedModels.length - 1];
|
||||
@@ -92,16 +104,40 @@ export class RegisteredModelService {
|
||||
}
|
||||
}
|
||||
}
|
||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||
return {
|
||||
id: +row[0].displayValue,
|
||||
artifactName: row[1].displayValue,
|
||||
title: row[2].displayValue,
|
||||
description: row[3].displayValue,
|
||||
version: row[4].displayValue,
|
||||
created: row[5].displayValue
|
||||
};
|
||||
}
|
||||
|
||||
private async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
const query = this.getUpdateModelQuery(connection.databaseName, model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result?.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
}
|
||||
}
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private getConfigureQuery(currentDatabaseName: string): string {
|
||||
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
|
||||
public getConfigureQuery(currentDatabaseName: string): string {
|
||||
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.getConfigureTableQuery());
|
||||
}
|
||||
|
||||
private registeredModelsQuery(): string {
|
||||
public getDeployedModelsQuery(): string {
|
||||
return `
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
@@ -116,7 +152,7 @@ export class RegisteredModelService {
|
||||
* @param databaseName
|
||||
* @param tableName
|
||||
*/
|
||||
private configureTable(): string {
|
||||
public getConfigureTableQuery(): string {
|
||||
let databaseName = this._config.registeredModelDatabaseName;
|
||||
let tableName = this._config.registeredModelTableName;
|
||||
let schemaName = this._config.registeredModelTableSchemaName;
|
||||
@@ -171,7 +207,7 @@ export class RegisteredModelService {
|
||||
`;
|
||||
}
|
||||
|
||||
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
|
||||
let updateScript = `
|
||||
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||
SET
|
||||
@@ -187,4 +223,12 @@ export class RegisteredModelService {
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
|
||||
public getModelContentQuery(model: RegisteredModel): string {
|
||||
return `
|
||||
SELECT artifact_content
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
}
|
||||
@@ -53,6 +53,16 @@ export interface RegisteredModel extends RegisteredModelDetails {
|
||||
artifactName: string;
|
||||
}
|
||||
|
||||
export interface ModelParameter {
|
||||
name: string;
|
||||
type: string;
|
||||
}
|
||||
|
||||
export interface ModelParameters {
|
||||
inputs: ModelParameter[],
|
||||
outputs: ModelParameter[]
|
||||
}
|
||||
|
||||
/**
|
||||
* An interface representing registered model
|
||||
*/
|
||||
|
||||
@@ -13,33 +13,89 @@ import * as utils from '../common/utils';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
import * as os from 'os';
|
||||
import { ModelParameters } from './interfaces';
|
||||
|
||||
/**
|
||||
* Service to import model to database
|
||||
* Python client for ONNX models
|
||||
*/
|
||||
export class ModelImporter {
|
||||
export class ModelPythonClient {
|
||||
|
||||
/**
|
||||
*
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
|
||||
}
|
||||
|
||||
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
/**
|
||||
* Deploys models in the SQL database using mlflow
|
||||
* @param connection
|
||||
* @param modelPath
|
||||
*/
|
||||
public async deployModel(connection: azdata.connection.ConnectionProfile, modelPath: string): Promise<void> {
|
||||
await this.installDependencies();
|
||||
await this.executeScripts(connection, modelFolderPath);
|
||||
await this.executeDeployScripts(connection, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for model importer
|
||||
* Installs dependencies for python client
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
private async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
|
||||
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
|
||||
}
|
||||
|
||||
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
/**
|
||||
*
|
||||
* @param modelPath Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
|
||||
await this.installDependencies();
|
||||
return await this.executeModelParametersScripts(modelPath);
|
||||
}
|
||||
|
||||
private async executeModelParametersScripts(modelFolderPath: string): Promise<ModelParameters> {
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
let scripts: string[] = [
|
||||
'import onnx',
|
||||
'import json',
|
||||
`onnx_model_path = '${modelFolderPath}'`,
|
||||
`onnx_model = onnx.load_model(onnx_model_path)`,
|
||||
`type_map = {
|
||||
onnx.TensorProto.DataType.FLOAT: 'real',
|
||||
onnx.TensorProto.DataType.UINT8: 'tinyint',
|
||||
onnx.TensorProto.DataType.INT16: 'smallint',
|
||||
onnx.TensorProto.DataType.INT32: 'int',
|
||||
onnx.TensorProto.DataType.INT64: 'bigint',
|
||||
onnx.TensorProto.DataType.STRING: 'varchar(MAX)',
|
||||
onnx.TensorProto.DataType.DOUBLE: 'float'}`,
|
||||
`parameters = {
|
||||
"inputs": [],
|
||||
"outputs": []
|
||||
}`,
|
||||
`def addParameters(list, paramType):
|
||||
for id, p in enumerate(list):
|
||||
p_type = ''
|
||||
|
||||
if p.type.tensor_type.elem_type in type_map:
|
||||
p_type = type_map[p.type.tensor_type.elem_type]
|
||||
|
||||
parameters[paramType].append({
|
||||
'name': p.name,
|
||||
'type': p_type
|
||||
})`,
|
||||
|
||||
'addParameters(onnx_model.graph.input, "inputs")',
|
||||
'addParameters(onnx_model.graph.output, "outputs")',
|
||||
'print(json.dumps(parameters))'
|
||||
];
|
||||
let pythonExecutable = this._config.pythonExecutable;
|
||||
let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
|
||||
let parametersJson = JSON.parse(output);
|
||||
return Object.assign({}, parametersJson);
|
||||
}
|
||||
|
||||
private async executeDeployScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
let home = utils.makeLinuxPath(os.homedir());
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
Reference in New Issue
Block a user