Machine Learning Model Registry - Iteration1 (#9105)

* Machine learning services extension - model registration wizard
This commit is contained in:
Leila Lali
2020-02-26 09:19:48 -08:00
committed by GitHub
parent 067fcc8dfb
commit ff207859d6
46 changed files with 3990 additions and 210 deletions

View File

@@ -0,0 +1,62 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import * as Models from './interfaces';
import * as Mappers from './mappers';
import * as Parameters from './parameters';
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
export class Artifacts {
private readonly client: AzureMachineLearningWorkspacesContext;
constructor(client: AzureMachineLearningWorkspacesContext) {
this.client = client;
}
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options?: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams): Promise<Models.GetArtifactContentInformation2Response>;
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, callback: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): void;
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams, callback: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): void;
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options?: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams | msRest.ServiceCallback<Models.ArtifactContentInformationDto>, callback?: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): Promise<Models.GetArtifactContentInformation2Response> {
return this.client.sendOperationRequest(
{
subscriptionId,
resourceGroupName,
workspaceName,
origin,
container,
options
},
getArtifactContentInformation2OperationSpec,
callback) as Promise<Models.GetArtifactContentInformation2Response>;
}
}
const serializer = new msRest.Serializer(Mappers);
const getArtifactContentInformation2OperationSpec: msRest.OperationSpec = {
httpMethod: 'GET',
path: 'artifact/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}/artifacts/contentinfo/{origin}/{container}',
urlParameters: [
Parameters.subscriptionId,
Parameters.resourceGroupName,
Parameters.workspaceName,
Parameters.origin,
Parameters.container,
Parameters.apiVersion
],
queryParameters: [
Parameters.projectName0,
Parameters.path1,
Parameters.accountName
],
responses: {
200: {
bodyMapper: Mappers.ArtifactContentInformationDto
},
default: {}
},
serializer
};

View File

@@ -0,0 +1,78 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import * as Models from './interfaces';
import * as Mappers from './mappers';
import * as Parameters from './parameters';
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
export class Assets {
private readonly client: AzureMachineLearningWorkspacesContext;
constructor(client: AzureMachineLearningWorkspacesContext) {
this.client = client;
}
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
options?: msRest.RequestOptionsBase
): Promise<Models.AssetsQueryByIdResponse>;
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
callback: msRest.ServiceCallback<Models.Asset>
): void;
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
options: msRest.RequestOptionsBase,
callback: msRest.ServiceCallback<Models.Asset>
): void;
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
options?: msRest.RequestOptionsBase | msRest.ServiceCallback<Models.Asset>,
callback?: msRest.ServiceCallback<Models.Asset>
): Promise<Models.AssetsQueryByIdResponse> {
return this.client.sendOperationRequest(
{
subscriptionId,
resourceGroup,
workspace,
id,
options
},
queryByIdOperationSpec,
callback
) as Promise<Models.AssetsQueryByIdResponse>;
}
}
const serializer = new msRest.Serializer(Mappers);
const queryByIdOperationSpec: msRest.OperationSpec = {
httpMethod: 'GET',
path:
'modelmanagement/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/Microsoft.MachineLearningServices/workspaces/{workspace}/assets/{id}',
urlParameters: [Parameters.subscriptionId, Parameters.resourceGroup, Parameters.workspace, Parameters.id],
responses: {
200: {
bodyMapper: Mappers.Asset
},
default: {
bodyMapper: Mappers.ModelErrorResponse
}
},
serializer
};

View File

@@ -0,0 +1,296 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../common/apiWrapper';
import * as constants from '../common/constants';
import { azureResource } from '../typings/azure-resource';
import { AzureMachineLearningWorkspaces } from '@azure/arm-machinelearningservices';
import { TokenCredentials } from '@azure/ms-rest-js';
import { WorkspaceModels } from './workspacesModels';
import { AzureMachineLearningWorkspacesOptions, Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel, Asset, IArtifactParts } from './interfaces';
import { Config } from '../configurations/config';
import { Assets } from './assets';
import * as polly from 'polly-js';
import { Artifacts } from './artifacts';
import { HttpClient } from '../common/httpClient';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as path from 'path';
import * as os from 'os';
/**
* Azure Model Service
*/
export class AzureModelRegistryService {
/**
*
*/
constructor(private _apiWrapper: ApiWrapper, private _config: Config, private _httpClient: HttpClient, private _outputChannel: vscode.OutputChannel) {
}
/**
* Returns list of azure accounts
*/
public async getAccounts(): Promise<azdata.Account[]> {
return await this._apiWrapper.getAllAccounts();
}
/**
* Returns list of azure subscriptions
* @param account azure account
*/
public async getSubscriptions(account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[] | undefined> {
const data = <azureResource.GetSubscriptionsResult>await this._apiWrapper.executeCommand(constants.azureSubscriptionsCommand, account, true);
return data?.subscriptions;
}
/**
* Returns list of azure groups
* @param account azure account
* @param subscription azure subscription
*/
public async getGroups(
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[] | undefined> {
const data = <azureResource.GetResourceGroupsResult>await this._apiWrapper.executeCommand(constants.azureResourceGroupsCommand, account, subscription, true);
return data?.resourceGroups;
}
/**
* Returns list of workspaces
* @param account azure account
* @param subscription azure subscription
* @param resourceGroup azure resource group
*/
public async getWorkspaces(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
return await this.fetchWorkspaces(account, subscription, resourceGroup);
}
/**
* Returns list of models
* @param account azure account
* @param subscription azure subscription
* @param resourceGroup azure resource group
* @param workspace azure workspace
*/
public async getModels(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace): Promise<WorkspaceModel[] | undefined> {
return await this.fetchModels(account, subscription, resourceGroup, workspace);
}
/**
* Download an azure model to a temporary location
* @param account azure account
* @param subscription azure subscription
* @param resourceGroup azure resource group
* @param workspace azure workspace
* @param model azure model
*/
public async downloadModel(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel): Promise<string> {
let downloadedFilePath: string = '';
for (const tenant of account.properties.tenants) {
try {
const downloadUrls = await this.getAssetArtifactsDownloadLinks(account, subscription, resourceGroup, workspace, model, tenant);
if (downloadUrls && downloadUrls.length > 0) {
downloadedFilePath = await this.downloadArtifact(downloadUrls[0]);
}
} catch (error) {
console.log(error);
}
}
return downloadedFilePath;
}
/**
* Installs dependencies for the extension
*/
public async downloadArtifact(downloadUrl: string): Promise<string> {
return new Promise<string>((resolve, reject) => {
let msgTaskName = constants.downloadModelMsgTaskName;
this._apiWrapper.startBackgroundOperation({
displayName: msgTaskName,
description: msgTaskName,
isCancelable: false,
operation: async op => {
let tempFilePath: string = '';
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await this._httpClient.download(downloadUrl, tempFilePath, op, this._outputChannel);
op.updateStatus(azdata.TaskStatus.Succeeded);
resolve(tempFilePath);
} catch (error) {
let errorMsg = constants.installDependenciesError(error ? error.message : '');
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
reject(errorMsg);
}
}
});
});
}
private async fetchWorkspaces(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
let resources: Workspace[] = [];
try {
for (const tenant of account.properties.tenants) {
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
const token = tokens[tenant.id].token;
const tokenType = tokens[tenant.id].tokenType;
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id);
let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription();
resources.push(...result);
}
} catch (error) {
}
return resources;
}
private async fetchModels(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace): Promise<WorkspaceModel[]> {
let resources: WorkspaceModel[] = [];
for (const tenant of account.properties.tenants) {
try {
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
const client = await this.getClient(baseUri, account, subscription, tenant);
let modelsClient = new WorkspaceModels(client);
resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || ''));
} catch (error) {
console.log(error);
}
}
return resources;
}
private async fetchModelAsset(
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel,
client: AzureMachineLearningWorkspaces): Promise<Asset> {
const modelId = this.getModelId(model);
let modelsClient = new Assets(client);
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
}
public async getAssetArtifactsDownloadLinks(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel,
tenant: any): Promise<string[]> {
let baseUri = this.getBaseUrl(workspace, this._config.amlModelManagementUrl);
const modelManagementClient = await this.getClient(baseUri, account, subscription, tenant);
const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient);
baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
const experienceClient = await this.getClient(baseUri, account, subscription, tenant);
const artifactClient = new Artifacts(experienceClient);
let downloadLinks: string[] = [];
if (asset && asset.artifacts) {
const downloadLinkPromises: Array<Promise<string>> = [];
for (const artifact of asset.artifacts) {
const parts = artifact.id
? this.getPartsFromAssetIdOrPrefix(artifact.id)
: this.getPartsFromAssetIdOrPrefix(artifact.prefix);
if (parts) {
const promise = polly()
.waitAndRetry(3)
.executeForPromise(
async (): Promise<string> => {
const artifact = await artifactClient.getArtifactContentInformation2(
experienceClient.subscriptionId,
resourceGroup.name,
workspace.name || '',
parts.origin,
parts.container,
{ path: parts.path }
);
if (artifact) {
return artifact.contentUri || '';
} else {
return Promise.reject();
}
}
);
downloadLinkPromises.push(promise);
}
}
try {
downloadLinks = await Promise.all(downloadLinkPromises);
} catch (rejectedPromiseError) {
return rejectedPromiseError;
}
}
return downloadLinks;
}
public getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/;
if (idOrPrefix) {
const parts = artifactRegex.exec(idOrPrefix);
if (parts && parts.length === 4) {
return {
origin: parts[1],
container: parts[2],
path: parts[3]
};
}
}
return undefined;
}
private getBaseUrl(workspace: Workspace, server: string): string {
let baseUri = `https://${workspace.location}.${server}`;
if (workspace.location === 'chinaeast2') {
baseUri = `https://${workspace.location}.${server}`;
}
return baseUri;
}
private async getClient(baseUri: string, account: azdata.Account, subscription: azureResource.AzureResourceSubscription, tenant: any): Promise<AzureMachineLearningWorkspaces> {
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
const token = tokens[tenant.id].token;
const tokenType = tokens[tenant.id].tokenType;
const options: AzureMachineLearningWorkspacesOptions = {
baseUri: baseUri
};
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
client.apiVersion = this._config.amlApiVersion;
return client;
}
private getModelId(model: WorkspaceModel): string {
const amlAssetRegex = /^aml:\/\/asset\/(.+)$/;
const id = model ? amlAssetRegex.exec(model.url || '') : undefined;
return id && id.length === 2 ? id[1] : '';
}
}

View File

@@ -0,0 +1,212 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import { Resource } from '@azure/arm-machinelearningservices/esm/models';
/**
* An interface representing ListWorkspaceModelResult.
*/
export interface ListWorkspaceModelsResult extends Array<WorkspaceModel> {
}
/**
* An interface representing Workspace model
*/
export interface WorkspaceModel extends Resource {
framework?: string;
frameworkVersion?: string;
createdBy?: string;
createdTime?: string;
experimentName?: string;
outputsSchema?: Array<string>;
url?: string;
}
/**
* An interface representing Workspace model list response
*/
export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
/**
* The underlying HTTP response.
*/
_response: msRest.HttpResponse & {
/**
* The response body as text (string format)
*/
bodyAsText: string;
/**
* The response body as parsed JSON or XML
*/
parsedBody: ListWorkspaceModelsResult;
};
};
/**
* An interface representing registered model
*/
export interface RegisteredModel {
id: number,
name: string
}
/**
* The Artifact definition.
*/
export interface ArtifactDetails {
/**
* The Artifact Id.
*/
id?: string;
/**
* The Artifact prefix.
*/
prefix?: string;
}
/**
* @interface
* An interface representing Asset.
* The Asset definition.
*
*/
export interface Asset {
/**
* @member {string} [id] The Asset Id.
*/
id?: string;
/**
* @member {string} [name] The name of the Asset.
*/
name?: string;
/**
* @member {string} [description] The Asset description.
*/
description?: string;
/**
* @member {ArtifactDetails[]} [artifacts] A list of child artifacts.
*/
artifacts?: ArtifactDetails[];
/**
* @member {string[]} [tags] The Asset tag list.
*/
tags?: string[];
/**
* @member {{ [propertyName: string]: string }} [kvTags] The Asset tag
* dictionary. Tags are mutable.
*/
kvTags?: { [propertyName: string]: string };
/**
* @member {{ [propertyName: string]: string }} [properties] The Asset
* property dictionary. Properties are immutable.
*/
properties?: { [propertyName: string]: string };
/**
* @member {string} [runid] The RunId associated with this Asset.
*/
runid?: string;
/**
* @member {string} [projectid] The project Id.
*/
projectid?: string;
/**
* @member {{ [propertyName: string]: string }} [meta] A dictionary
* containing metadata about the Asset.
*/
meta?: { [propertyName: string]: string };
/**
* @member {Date} [createdTime] The time the Asset was created in UTC.
*/
createdTime?: Date;
}
/**
* Contains response data for the queryById operation.
*/
export type AssetsQueryByIdResponse = Asset & {
/**
* The underlying HTTP response.
*/
_response: msRest.HttpResponse & {
/**
* The response body as text (string format)
*/
bodyAsText: string;
/**
* The response body as parsed JSON or XML
*/
parsedBody: Asset;
};
};
export interface IArtifactParts {
origin: string;
container: string;
path: string;
}
/**
* @interface
* An interface representing ArtifactContentInformationDto.
*/
export interface ArtifactContentInformationDto {
/**
* @member {string} [contentUri]
*/
contentUri?: string;
/**
* @member {string} [origin]
*/
origin?: string;
/**
* @member {string} [container]
*/
container?: string;
/**
* @member {string} [path]
*/
path?: string;
}
/**
* Contains response data for the getArtifactContentInformation2 operation.
*/
export type GetArtifactContentInformation2Response = ArtifactContentInformationDto & {
/**
* The underlying HTTP response.
*/
_response: msRest.HttpResponse & {
/**
* The response body as text (string format)
*/
bodyAsText: string;
/**
* The response body as parsed JSON or XML
*/
parsedBody: ArtifactContentInformationDto;
};
};
/**
* @interface
* An interface representing ArtifactAPIGetArtifactContentInformation2OptionalParams.
* Optional Parameters.
*
* @extends RequestOptionsBase
*/
export interface ArtifactAPIGetArtifactContentInformation2OptionalParams extends msRest.RequestOptionsBase {
/**
* @member {string} [projectName]
*/
projectName?: string;
/**
* @member {string} [path]
*/
path?: string;
/**
* @member {string} [accountName]
*/
accountName?: string;
}

View File

@@ -0,0 +1,320 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
export const Resource: msRest.CompositeMapper = {
serializedName: 'Resource',
type: {
name: 'Composite',
className: 'Resource',
modelProperties: {
id: {
readOnly: true,
serializedName: 'id',
type: {
name: 'String'
}
},
name: {
readOnly: true,
serializedName: 'name',
type: {
name: 'String'
}
},
identity: {
readOnly: true,
serializedName: 'identity',
type: {
name: 'Composite',
className: 'Identity'
}
},
location: {
serializedName: 'location',
type: {
name: 'String'
}
},
type: {
readOnly: true,
serializedName: 'type',
type: {
name: 'String'
}
},
tags: {
serializedName: 'tags',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
}
}
}
};
export const ListWorkspaceModelsResult: msRest.CompositeMapper = {
serializedName: 'ListWorkspaceModelsResult',
type: {
name: 'Composite',
className: 'ListWorkspaceModelsResult',
modelProperties: {
value: {
serializedName: '',
type: {
name: 'Sequence',
element: {
type: {
name: 'Composite',
className: 'WorkspaceModel'
}
}
}
},
nextLink: {
serializedName: 'nextLink',
type: {
name: 'String'
}
}
}
}
};
export const WorkspaceModel: msRest.CompositeMapper = {
serializedName: 'WorkspaceModel',
type: {
name: 'Composite',
className: 'WorkspaceModel',
modelProperties: {
...Resource.type.modelProperties,
framework: {
readOnly: true,
serializedName: 'framework',
type: {
name: 'String'
}
},
}
}
};
export const MachineLearningServiceError: msRest.CompositeMapper = {
serializedName: 'MachineLearningServiceError',
type: {
name: 'Composite',
className: 'MachineLearningServiceError',
modelProperties: {
error: {
readOnly: true,
serializedName: 'error',
type: {
name: 'Composite',
className: 'ErrorResponse'
}
}
}
}
};
export const ModelErrorResponse: msRest.CompositeMapper = {
serializedName: 'ModelErrorResponse',
type: {
name: 'Composite',
className: 'ModelErrorResponse',
modelProperties: {
code: {
serializedName: 'code',
type: {
name: 'String'
}
},
statusCode: {
serializedName: 'statusCode',
type: {
name: 'Number'
}
},
message: {
serializedName: 'message',
type: {
name: 'String'
}
},
details: {
serializedName: 'details',
type: {
name: 'Sequence',
element: {
type: {
name: 'Composite',
className: 'ErrorDetails'
}
}
}
}
}
}
};
export const ArtifactDetails: msRest.CompositeMapper = {
serializedName: 'ArtifactDetails',
type: {
name: 'Composite',
className: 'ArtifactDetails',
modelProperties: {
id: {
serializedName: 'id',
type: {
name: 'String'
}
},
prefix: {
serializedName: 'prefix',
type: {
name: 'String'
}
}
}
}
};
export const Asset: msRest.CompositeMapper = {
serializedName: 'Asset',
type: {
name: 'Composite',
className: 'Asset',
modelProperties: {
id: {
serializedName: 'id',
type: {
name: 'String'
}
},
name: {
serializedName: 'name',
type: {
name: 'String'
}
},
description: {
serializedName: 'description',
type: {
name: 'String'
}
},
artifacts: {
serializedName: 'artifacts',
type: {
name: 'Sequence',
element: {
type: {
name: 'Composite',
className: 'ArtifactDetails'
}
}
}
},
tags: {
serializedName: 'tags',
type: {
name: 'Sequence',
element: {
type: {
name: 'String'
}
}
}
},
kvTags: {
serializedName: 'kvTags',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
},
properties: {
serializedName: 'properties',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
},
runid: {
serializedName: 'runid',
type: {
name: 'String'
}
},
projectid: {
serializedName: 'projectid',
type: {
name: 'String'
}
},
meta: {
serializedName: 'meta',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
},
createdTime: {
serializedName: 'createdTime',
type: {
name: 'DateTime'
}
}
}
}
};
export const ArtifactContentInformationDto: msRest.CompositeMapper = {
serializedName: 'ArtifactContentInformationDto',
type: {
name: 'Composite',
className: 'ArtifactContentInformationDto',
modelProperties: {
contentUri: {
serializedName: 'contentUri',
type: {
name: 'String'
}
},
origin: {
serializedName: 'origin',
type: {
name: 'String'
}
},
container: {
serializedName: 'container',
type: {
name: 'String'
}
},
path: {
serializedName: 'path',
type: {
name: 'String'
}
}
}
}
};

View File

@@ -0,0 +1,56 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ProcessService } from '../common/processService';
import { Config } from '../configurations/config';
import { ApiWrapper } from '../common/apiWrapper';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
/**
* Service to import model to database
*/
export class ModelImporter {
/**
*
*/
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config) {
}
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
await this.executeScripts(connection, modelFolderPath);
}
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
const parts = modelFolderPath.split('\\');
modelFolderPath = parts.join('/');
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
if (connection) {
let server = connection.serverName;
const experimentId = `ads_ml_experiment_${UUID.generateUuid()}`;
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}` : '';
let scripts: string[] = [
'import mlflow.onnx',
'import onnx',
'from mlflow.tracking.client import MlflowClient',
`onx = onnx.load("${modelFolderPath}")`,
'client = MlflowClient()',
`exp_name = "${experimentId}"`,
`db_uri_artifact = "mssql+pyodbc://${credential}@${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server"`,
'client.create_experiment(exp_name, artifact_location=db_uri_artifact)',
'mlflow.set_experiment(exp_name)',
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'
];
let pythonExecutable = this._config.pythonExecutable;
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
}
}
}

View File

@@ -0,0 +1,143 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
export const subscriptionId: msRest.OperationURLParameter = {
parameterPath: 'subscriptionId',
mapper: {
required: true,
serializedName: 'subscriptionId',
type: {
name: 'String'
}
}
};
export const resourceGroupName: msRest.OperationURLParameter = {
parameterPath: 'resourceGroupName',
mapper: {
required: true,
serializedName: 'resourceGroupName',
type: {
name: 'String'
}
}
};
export const workspaceName: msRest.OperationURLParameter = {
parameterPath: 'workspaceName',
mapper: {
required: true,
serializedName: 'workspaceName',
type: {
name: 'String'
}
}
};
export const workspace: msRest.OperationURLParameter = {
parameterPath: 'workspace',
mapper: {
required: true,
serializedName: 'workspace',
type: {
name: 'String'
}
}
};
export const resourceGroup: msRest.OperationURLParameter = {
parameterPath: 'resourceGroup',
mapper: {
required: true,
serializedName: 'resourceGroup',
type: {
name: 'String'
}
}
};
export const id: msRest.OperationURLParameter = {
parameterPath: 'id',
mapper: {
required: true,
serializedName: 'id',
type: {
name: 'String'
}
}
};
export const acceptLanguage: msRest.OperationParameter = {
parameterPath: 'acceptLanguage',
mapper: {
serializedName: 'accept-language',
defaultValue: 'en-US',
type: {
name: 'String'
}
}
};
export const apiVersion: msRest.OperationQueryParameter = {
parameterPath: 'apiVersion',
mapper: {
required: true,
serializedName: 'api-version',
type: {
name: 'String'
}
}
};
export const origin: msRest.OperationURLParameter = {
parameterPath: 'origin',
mapper: {
required: true,
serializedName: 'origin',
type: {
name: 'String'
}
}
};
export const container: msRest.OperationURLParameter = {
parameterPath: 'container',
mapper: {
required: true,
serializedName: 'container',
type: {
name: 'String'
}
}
};
export const projectName0: msRest.OperationQueryParameter = {
parameterPath: [
'options',
'projectName'
],
mapper: {
serializedName: 'projectName',
type: {
name: 'String'
}
}
};
export const path1: msRest.OperationQueryParameter = {
parameterPath: [
'options',
'path'
],
mapper: {
serializedName: 'path',
type: {
name: 'String'
}
}
};
export const accountName: msRest.OperationQueryParameter = {
parameterPath: [
'options',
'accountName'
],
mapper: {
serializedName: 'accountName',
type: {
name: 'String'
}
}
};

View File

@@ -0,0 +1,77 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
import { RegisteredModel } from './interfaces';
import { ModelImporter } from './modelImporter';
/**
* Service to registered models
*/
export class RegisteredModelService {
/**
*
*/
constructor(
private _apiWrapper: ApiWrapper,
private _config: Config,
private _queryRunner: QueryRunner,
private _modelImporter: ModelImporter) {
}
public async getRegisteredModels(): Promise<RegisteredModel[]> {
let connection = await this.getCurrentConnection();
let list: RegisteredModel[] = [];
if (connection) {
let result = await this.runRegisteredModelsListQuery(connection);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push({
id: +row[0].displayValue,
name: row[1].displayValue
});
});
}
}
return list;
}
public async registerLocalModel(filePath: string) {
let connection = await this.getCurrentConnection();
if (connection) {
await this._modelImporter.registerModel(connection, filePath);
}
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
} catch {
return undefined;
}
}
private registeredModelsQuery(databaseName: string, tableName: string) {
return `
IF (EXISTS (SELECT name
FROM master.dbo.sysdatabases
WHERE ('[' + name + ']' = '${databaseName}'
OR name = '${databaseName}')))
BEGIN
SELECT artifact_id, artifact_name, group_path, artifact_initial_size from ${databaseName}.${tableName}
WHERE artifact_name like '%.onnx'
END
`;
}
}

View File

@@ -0,0 +1,64 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
import * as Models from './interfaces';
import * as Mappers from './mappers';
import * as Parameters from './parameters';
/**
* Workspace models client
*/
export class WorkspaceModels {
private readonly client: AzureMachineLearningWorkspacesContext;
constructor(client: AzureMachineLearningWorkspacesContext) {
this.client = client;
}
listModels(resourceGroupName: string, workspaceName: string, options?: msRest.RequestOptionsBase): Promise<Models.ListWorkspaceModelsResult>;
listModels(resourceGroupName: string, workspaceName: string, callback: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): void;
listModels(resourceGroupName: string, workspaceName: string, options: msRest.RequestOptionsBase, callback: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): void;
listModels(resourceGroupName: string, workspaceName: string, options?: msRest.RequestOptionsBase | msRest.ServiceCallback<Models.ListWorkspaceModelsResult>, callback?: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): Promise<Models.WorkspacesModelsResponse> {
return this.client.sendOperationRequest(
{
resourceGroupName,
workspaceName,
options
},
listModelsOperationSpec,
callback) as Promise<Models.WorkspacesModelsResponse>;
}
}
const serializer = new msRest.Serializer(Mappers);
const listModelsOperationSpec: msRest.OperationSpec = {
httpMethod: 'GET',
path:
'modelmanagement/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}/models',
urlParameters: [
Parameters.subscriptionId,
Parameters.resourceGroupName,
Parameters.workspaceName
],
queryParameters: [
Parameters.apiVersion
],
headerParameters: [
Parameters.acceptLanguage
],
responses: {
200: {
bodyMapper: Mappers.ListWorkspaceModelsResult
},
default: {
bodyMapper: Mappers.MachineLearningServiceError
}
},
serializer
};