Machine Learning Services - Model detection in predict wizard (#9609)

* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
Leila Lali
2020-03-25 13:18:19 -07:00
committed by GitHub
parent 176edde2aa
commit ab82c04766
44 changed files with 2265 additions and 376 deletions

View File

@@ -28,10 +28,16 @@ import * as utils from '../common/utils';
*/
export class AzureModelRegistryService {
private _amlClient: AzureMachineLearningWorkspaces | undefined;
private _modelClient: WorkspaceModels | undefined;
/**
*
* Creates new service
*/
constructor(private _apiWrapper: ApiWrapper, private _config: Config, private _httpClient: HttpClient, private _outputChannel: vscode.OutputChannel) {
constructor(
private _apiWrapper: ApiWrapper,
private _config: Config,
private _httpClient: HttpClient,
private _outputChannel: vscode.OutputChannel) {
}
/**
@@ -120,10 +126,18 @@ export class AzureModelRegistryService {
return downloadedFilePath;
}
public set AzureMachineLearningClient(value: AzureMachineLearningWorkspaces) {
this._amlClient = value;
}
public set ModelClient(value: WorkspaceModels) {
this._modelClient = value;
}
/**
* Installs dependencies for the extension
* Execute the background task to download the artifact
*/
public async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
private async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
return results && results.length > 0 ? results[0] : constants.noResultError;
}
@@ -139,15 +153,14 @@ export class AzureModelRegistryService {
try {
for (const tenant of account.properties.tenants) {
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
const token = tokens[tenant.id].token;
const tokenType = tokens[tenant.id].tokenType;
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id);
const client = await this.getAmlClient(account, subscription, tenant);
let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription();
resources.push(...result);
if (result) {
resources.push(...result);
}
}
} catch (error) {
console.log(error);
}
return resources;
}
@@ -161,9 +174,11 @@ export class AzureModelRegistryService {
for (const tenant of account.properties.tenants) {
try {
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
const client = await this.getClient(baseUri, account, subscription, tenant);
let modelsClient = new WorkspaceModels(client);
let options: AzureMachineLearningWorkspacesOptions = {
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
};
const client = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
let modelsClient = this.getModelClient(client);
resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || ''));
} catch (error) {
@@ -182,22 +197,28 @@ export class AzureModelRegistryService {
client: AzureMachineLearningWorkspaces): Promise<Asset> {
const modelId = this.getModelId(model);
let modelsClient = new Assets(client);
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
if (modelId) {
let modelsClient = new Assets(client);
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
} else {
throw Error(constants.invalidModelIdError(model.url));
}
}
public async getAssetArtifactsDownloadLinks(
private async getAssetArtifactsDownloadLinks(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel,
tenant: any): Promise<string[]> {
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
const modelManagementClient = await this.getClient(baseUri, account, subscription, tenant);
let options: AzureMachineLearningWorkspacesOptions = {
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
};
const modelManagementClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient);
baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
const experienceClient = await this.getClient(baseUri, account, subscription, tenant);
options.baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
const experienceClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
const artifactClient = new Artifacts(experienceClient);
let downloadLinks: string[] = [];
if (asset && asset.artifacts) {
@@ -230,17 +251,19 @@ export class AzureModelRegistryService {
downloadLinkPromises.push(promise);
}
}
try {
downloadLinks = await Promise.all(downloadLinkPromises);
} catch (rejectedPromiseError) {
return rejectedPromiseError;
}
return downloadLinks;
} else {
throw Error(constants.noArtifactError(model.url));
}
return downloadLinks;
}
public getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
private getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/;
if (idOrPrefix) {
const parts = artifactRegex.exec(idOrPrefix);
@@ -263,16 +286,35 @@ export class AzureModelRegistryService {
return baseUri;
}
private async getClient(baseUri: string, account: azdata.Account, subscription: azureResource.AzureResourceSubscription, tenant: any): Promise<AzureMachineLearningWorkspaces> {
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
const token = tokens[tenant.id].token;
const tokenType = tokens[tenant.id].tokenType;
const options: AzureMachineLearningWorkspacesOptions = {
baseUri: baseUri
};
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
client.apiVersion = this._config.amlApiVersion;
return client;
private getModelClient(amlClient: AzureMachineLearningWorkspaces) {
return this._modelClient ?? new WorkspaceModels(amlClient);
}
private async getAmlClient(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
tenant: any,
options: AzureMachineLearningWorkspacesOptions | undefined = undefined,
apiVersion: string | undefined = undefined): Promise<AzureMachineLearningWorkspaces> {
if (this._amlClient) {
return this._amlClient;
} else {
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
let token: string = '';
let tokenType: string | undefined = undefined;
if (tokens && tenant.id in tokens) {
const tokenForId = tokens[tenant.id];
if (tokenForId) {
token = tokenForId.token;
tokenType = tokenForId.tokenType;
}
}
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
if (apiVersion) {
client.apiVersion = apiVersion;
}
return client;
}
}
private getModelId(model: WorkspaceModel): string {

View File

@@ -9,73 +9,85 @@ import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
import { ModelImporter } from './modelImporter';
import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces';
import { ModelPythonClient } from './modelPythonClient';
import * as constants from '../common/constants';
/**
* Service to registered models
* Service to deployed models
*/
export class RegisteredModelService {
export class DeployedModelService {
/**
*
* Creates new instance
*/
constructor(
private _apiWrapper: ApiWrapper,
private _config: Config,
private _queryRunner: QueryRunner,
private _modelImporter: ModelImporter) {
private _modelClient: ModelPythonClient) {
}
public async getRegisteredModels(): Promise<RegisteredModel[]> {
/**
* Returns deployed models
*/
public async getDeployedModels(): Promise<RegisteredModel[]> {
let connection = await this.getCurrentConnection();
let list: RegisteredModel[] = [];
if (connection) {
let query = this.getConfigureQuery(connection.databaseName);
await this._queryRunner.safeRunQuery(connection, query);
query = this.registeredModelsQuery();
query = this.getDeployedModelsQuery();
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push(this.loadModelData(row));
});
}
} else {
throw Error(constants.noConnectionError);
}
return list;
}
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
return {
id: +row[0].displayValue,
artifactName: row[1].displayValue,
title: row[2].displayValue,
description: row[3].displayValue,
version: row[4].displayValue,
created: row[5].displayValue
};
}
public async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
/**
* Downloads model
* @param model model object
*/
public async downloadModel(model: RegisteredModel): Promise<string> {
let connection = await this.getCurrentConnection();
let updatedModel: RegisteredModel | undefined = undefined;
if (connection) {
const query = this.getUpdateModelScript(connection.databaseName, model);
const query = this.getModelContentQuery(model);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
const row = result.rows[0];
updatedModel = this.loadModelData(row);
const content = result.rows[0][0].displayValue;
return await utils.writeFileFromHex(content);
} else {
throw Error(constants.invalidModelToSelectError);
}
} else {
throw Error(constants.noConnectionError);
}
return updatedModel;
}
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
/**
* Loads model parameters
*/
public async loadModelParameters(filePath: string): Promise<ModelParameters> {
return await this._modelClient.loadModelParameters(filePath);
}
/**
* Deploys local model
* @param filePath model file path
* @param details model details
*/
public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
let connection = await this.getCurrentConnection();
if (connection) {
let currentModels = await this.getRegisteredModels();
await this._modelImporter.registerModel(connection, filePath);
let updatedModels = await this.getRegisteredModels();
let currentModels = await this.getDeployedModels();
await this._modelClient.deployModel(connection, filePath);
let updatedModels = await this.getDeployedModels();
if (details && updatedModels.length >= currentModels.length + 1) {
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
const addedModel = updatedModels[updatedModels.length - 1];
@@ -92,16 +104,40 @@ export class RegisteredModelService {
}
}
}
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
return {
id: +row[0].displayValue,
artifactName: row[1].displayValue,
title: row[2].displayValue,
description: row[3].displayValue,
version: row[4].displayValue,
created: row[5].displayValue
};
}
private async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
let connection = await this.getCurrentConnection();
let updatedModel: RegisteredModel | undefined = undefined;
if (connection) {
const query = this.getUpdateModelQuery(connection.databaseName, model);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result?.rows && result.rows.length > 0) {
const row = result.rows[0];
updatedModel = this.loadModelData(row);
}
}
return updatedModel;
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
private getConfigureQuery(currentDatabaseName: string): string {
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
public getConfigureQuery(currentDatabaseName: string): string {
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.getConfigureTableQuery());
}
private registeredModelsQuery(): string {
public getDeployedModelsQuery(): string {
return `
SELECT artifact_id, artifact_name, name, description, version, created
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
@@ -116,7 +152,7 @@ export class RegisteredModelService {
* @param databaseName
* @param tableName
*/
private configureTable(): string {
public getConfigureTableQuery(): string {
let databaseName = this._config.registeredModelDatabaseName;
let tableName = this._config.registeredModelTableName;
let schemaName = this._config.registeredModelTableSchemaName;
@@ -171,7 +207,7 @@ export class RegisteredModelService {
`;
}
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
public getUpdateModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
let updateScript = `
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
SET
@@ -187,4 +223,12 @@ export class RegisteredModelService {
WHERE artifact_id = ${model.id};
`;
}
public getModelContentQuery(model: RegisteredModel): string {
return `
SELECT artifact_content
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
WHERE artifact_id = ${model.id};
`;
}
}

View File

@@ -53,6 +53,16 @@ export interface RegisteredModel extends RegisteredModelDetails {
artifactName: string;
}
export interface ModelParameter {
name: string;
type: string;
}
export interface ModelParameters {
inputs: ModelParameter[],
outputs: ModelParameter[]
}
/**
* An interface representing registered model
*/

View File

@@ -13,33 +13,89 @@ import * as utils from '../common/utils';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
import * as os from 'os';
import { ModelParameters } from './interfaces';
/**
* Service to import model to database
* Python client for ONNX models
*/
export class ModelImporter {
export class ModelPythonClient {
/**
*
* Creates new instance
*/
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
}
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
/**
* Deploys models in the SQL database using mlflow
* @param connection
* @param modelPath
*/
public async deployModel(connection: azdata.connection.ConnectionProfile, modelPath: string): Promise<void> {
await this.installDependencies();
await this.executeScripts(connection, modelFolderPath);
await this.executeDeployScripts(connection, modelPath);
}
/**
* Installs dependencies for model importer
* Installs dependencies for python client
*/
public async installDependencies(): Promise<void> {
private async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
}
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
/**
*
* @param modelPath Loads model parameters
*/
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
await this.installDependencies();
return await this.executeModelParametersScripts(modelPath);
}
private async executeModelParametersScripts(modelFolderPath: string): Promise<ModelParameters> {
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
let scripts: string[] = [
'import onnx',
'import json',
`onnx_model_path = '${modelFolderPath}'`,
`onnx_model = onnx.load_model(onnx_model_path)`,
`type_map = {
onnx.TensorProto.DataType.FLOAT: 'real',
onnx.TensorProto.DataType.UINT8: 'tinyint',
onnx.TensorProto.DataType.INT16: 'smallint',
onnx.TensorProto.DataType.INT32: 'int',
onnx.TensorProto.DataType.INT64: 'bigint',
onnx.TensorProto.DataType.STRING: 'varchar(MAX)',
onnx.TensorProto.DataType.DOUBLE: 'float'}`,
`parameters = {
"inputs": [],
"outputs": []
}`,
`def addParameters(list, paramType):
for id, p in enumerate(list):
p_type = ''
if p.type.tensor_type.elem_type in type_map:
p_type = type_map[p.type.tensor_type.elem_type]
parameters[paramType].append({
'name': p.name,
'type': p_type
})`,
'addParameters(onnx_model.graph.input, "inputs")',
'addParameters(onnx_model.graph.output, "outputs")',
'print(json.dumps(parameters))'
];
let pythonExecutable = this._config.pythonExecutable;
let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
let parametersJson = JSON.parse(output);
return Object.assign({}, parametersJson);
}
private async executeDeployScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
let home = utils.makeLinuxPath(os.homedir());
modelFolderPath = utils.makeLinuxPath(modelFolderPath);