mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
Machine Learning Extension - Model details (#9377)
* Machine Learning Services Extension - adding model details
This commit is contained in:
@@ -1,19 +1,55 @@
|
||||
{
|
||||
"requiredPythonPackages": [
|
||||
{ "name": "pymssql", "version": "2.1.4" },
|
||||
{ "name": "sqlmlutils", "version": ""}
|
||||
],
|
||||
|
||||
"requiredRPackages": [
|
||||
{ "name": "RODBCext", "repository": "https://cran.microsoft.com" },
|
||||
{ "name": "sqlmlutils", "fileName": "sqlmlutils_0.7.1.zip", "downloadUrl": "https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true"}
|
||||
],
|
||||
|
||||
"rPackagesRepository": "https://cran.r-project.org",
|
||||
|
||||
"registeredModelsDatabaseName": "MlFlowDB",
|
||||
"registeredModelsTableName": "dbo.artifacts",
|
||||
"amlModelManagementUrl": "modelmanagement.azureml.net",
|
||||
"amlExperienceUrl": "experiments.azureml.net",
|
||||
"amlApiVersion": "2018-11-19"
|
||||
"sqlPackageManagement": {
|
||||
"requiredPythonPackages": [
|
||||
{
|
||||
"name": "pymssql",
|
||||
"version": "2.1.4"
|
||||
},
|
||||
{
|
||||
"name": "sqlmlutils",
|
||||
"version": ""
|
||||
}
|
||||
],
|
||||
"requiredRPackages": [
|
||||
{
|
||||
"name": "RODBCext",
|
||||
"repository": "https://cran.microsoft.com"
|
||||
},
|
||||
{
|
||||
"name": "sqlmlutils",
|
||||
"fileName": "sqlmlutils_0.7.1.zip",
|
||||
"downloadUrl": "https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true"
|
||||
}
|
||||
],
|
||||
"rPackagesRepository": "https://cran.r-project.org"
|
||||
},
|
||||
"modelManagement": {
|
||||
"registeredModelsDatabaseName": "MlFlowDB",
|
||||
"registeredModelsTableName": "artifacts",
|
||||
"amlModelManagementUrl": "modelmanagement.azureml.net",
|
||||
"amlExperienceUrl": "experiments.azureml.net",
|
||||
"amlApiVersion": "2018-11-19",
|
||||
"requiredPythonPackages": [
|
||||
{
|
||||
"name": "onnx",
|
||||
"version": ""
|
||||
},
|
||||
{
|
||||
"name": "onnxruntime",
|
||||
"version": ""
|
||||
},
|
||||
{
|
||||
"name": "mlflow",
|
||||
"version": ""
|
||||
},
|
||||
{
|
||||
"name": "pyodbc",
|
||||
"version": ""
|
||||
},
|
||||
{
|
||||
"name": "mlflow-dbstore",
|
||||
"version": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,4 +101,8 @@ export class ApiWrapper {
|
||||
public getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{ [key: string]: any }> {
|
||||
return azdata.accounts.getSecurityToken(account, resource);
|
||||
}
|
||||
|
||||
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
|
||||
return vscode.window.showQuickPick(items, options, token);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,9 +42,17 @@ export const rPathConfigKey = 'rPath';
|
||||
|
||||
// Localized texts
|
||||
//
|
||||
export const msgYes = localize('msgYes', "Yes");
|
||||
export const msgNo = localize('msgNo', "No");
|
||||
export const managePackageCommandError = localize('mls.managePackages.error', "Either no connection is available or the server does not have external script enabled.");
|
||||
export function installDependenciesError(err: string): string { return localize('mls.installDependencies.error', "Failed to install dependencies. Error: {0}", err); }
|
||||
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
|
||||
export const installDependenciesMsgTaskName = localize('mls.installDependencies.msgTaskName', "Installing Machine Learning extension dependencies");
|
||||
export const noResultError = localize('mls.noResultError', "No Result returned");
|
||||
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
|
||||
export function confirmInstallPythonPackages(packages: string): string {
|
||||
return localize('mls.installDependencies.confirmInstallPythonPackages'
|
||||
, "The following Python packages are required to install: {0}. Are you sure you want to install?", packages);
|
||||
}
|
||||
export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ...");
|
||||
export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
|
||||
export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
|
||||
@@ -101,23 +109,27 @@ export const extLangSelectedPath = localize('extLang.selectedPath', "Selected Pa
|
||||
export const extLangInstallFailedError = localize('extLang.installFailedError', "Failed to install language");
|
||||
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
|
||||
|
||||
export const modeIld = localize('models.id', "Id");
|
||||
export const modelArtifactName = localize('models.artifactName', "Artifact Name");
|
||||
export const modelName = localize('models.name', "Name");
|
||||
export const modelSize = localize('models.size', "Size");
|
||||
export const modelDescription = localize('models.description', "Description");
|
||||
export const modelCreated = localize('models.created', "Date Created");
|
||||
export const modelVersion = localize('models.version', "Version");
|
||||
export const browseModels = localize('models.browseButton', "...");
|
||||
export const azureAccount = localize('models.azureAccount', "Account");
|
||||
export const azureSubscription = localize('models.azureSubscription', "Subscription");
|
||||
export const azureGroup = localize('models.azureGroup', "Resource Group");
|
||||
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Workspace");
|
||||
export const azureAccount = localize('models.azureAccount', "Azure account");
|
||||
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
|
||||
export const azureGroup = localize('models.azureGroup', "Azure resource group");
|
||||
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
|
||||
export const azureModelFilter = localize('models.azureModelFilter', "Filter");
|
||||
export const azureModels = localize('models.azureModels', "Models");
|
||||
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
|
||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Ender model source details");
|
||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
|
||||
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||
export const azureRegisterModel = localize('models.azureRegisterModel', "Register");
|
||||
export const registerModelWizardTitle = localize('models.RegisterWizard', "Register");
|
||||
export const registerModelButton = localize('models.RegisterModelButton', "Register model");
|
||||
export const registerModelTitle = localize('models.RegisterWizard', "Register model");
|
||||
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
|
||||
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
|
||||
export const localModelSource = localize('models.localModelSource', "Upload file");
|
||||
@@ -125,6 +137,8 @@ export const azureModelSource = localize('models.azureModelSource', "Import from
|
||||
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
|
||||
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
|
||||
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as fs from 'fs';
|
||||
import * as request from 'request';
|
||||
@@ -36,7 +35,7 @@ export class HttpClient {
|
||||
});
|
||||
}
|
||||
|
||||
public download(downloadUrl: string, targetPath: string, backgroundOperation: azdata.BackgroundOperation, outputChannel: vscode.OutputChannel): Promise<void> {
|
||||
public download(downloadUrl: string, targetPath: string, outputChannel: vscode.OutputChannel): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
|
||||
let totalMegaBytes: number | undefined = undefined;
|
||||
@@ -44,12 +43,12 @@ export class HttpClient {
|
||||
let printThreshold = 0.1;
|
||||
let downloadRequest = request.get(downloadUrl, { timeout: DownloadTimeout })
|
||||
.on('error', downloadError => {
|
||||
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, constants.downloadError);
|
||||
outputChannel.appendLine(constants.downloadError);
|
||||
reject(downloadError);
|
||||
})
|
||||
.on('response', (response) => {
|
||||
if (response.statusCode !== 200) {
|
||||
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, constants.downloadError);
|
||||
outputChannel.appendLine(constants.downloadError);
|
||||
return reject(response.statusMessage);
|
||||
}
|
||||
let contentLength = response.headers['content-length'];
|
||||
@@ -73,7 +72,6 @@ export class HttpClient {
|
||||
resolve();
|
||||
})
|
||||
.on('error', (downloadError) => {
|
||||
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, 'Error');
|
||||
reject(downloadError);
|
||||
downloadRequest.abort();
|
||||
});
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as fs from 'fs';
|
||||
import * as constants from '../common/constants';
|
||||
import { promisify } from 'util';
|
||||
import { ApiWrapper } from './apiWrapper';
|
||||
|
||||
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
|
||||
let tempFilePath: string = '';
|
||||
@@ -101,3 +103,76 @@ export function sortPackageVersions(versions: string[], ascending: boolean = tru
|
||||
export function isWindows(): boolean {
|
||||
return process.platform === 'win32';
|
||||
}
|
||||
|
||||
/**
|
||||
* Escapes all single-quotes (') by prefixing them with another single quote ('')
|
||||
* ' => ''
|
||||
* @param value The string to escape
|
||||
*/
|
||||
export function doubleEscapeSingleQuotes(value: string): string {
|
||||
return value.replace(/'/g, '\'\'');
|
||||
}
|
||||
|
||||
/**
|
||||
* Escapes all single-bracket ([]) by replacing them with another bracket quote ([[]])
|
||||
* ' => ''
|
||||
* @param value The string to escape
|
||||
*/
|
||||
export function doubleEscapeSingleBrackets(value: string): string {
|
||||
return value.replace(/\[/g, '[[').replace(/\]/g, ']]');
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
*/
|
||||
export async function executeTasks<T>(apiWrapper: ApiWrapper, taskName: string, dependencies: PromiseLike<T>[], parallel: boolean): Promise<T[]> {
|
||||
return new Promise<T[]>((resolve, reject) => {
|
||||
let msgTaskName = taskName;
|
||||
apiWrapper.startBackgroundOperation({
|
||||
displayName: msgTaskName,
|
||||
description: msgTaskName,
|
||||
isCancelable: false,
|
||||
operation: async op => {
|
||||
try {
|
||||
let result: T[] = [];
|
||||
// Install required packages
|
||||
//
|
||||
if (parallel) {
|
||||
result = await Promise.all(dependencies);
|
||||
} else {
|
||||
for (let index = 0; index < dependencies.length; index++) {
|
||||
result.push(await dependencies[index]);
|
||||
}
|
||||
}
|
||||
op.updateStatus(azdata.TaskStatus.Succeeded);
|
||||
resolve(result);
|
||||
} catch (error) {
|
||||
let errorMsg = constants.taskFailedError(taskName, error ? error.message : '');
|
||||
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
|
||||
reject(errorMsg);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Promise<boolean> {
|
||||
let choices: { [id: string]: boolean } = {};
|
||||
choices[constants.msgYes] = true;
|
||||
choices[constants.msgNo] = false;
|
||||
|
||||
let options = {
|
||||
placeHolder: message
|
||||
};
|
||||
|
||||
let result = await apiWrapper.showQuickPick(Object.keys(choices).map(c => {
|
||||
return {
|
||||
label: c
|
||||
};
|
||||
}), options);
|
||||
if (result === undefined) {
|
||||
throw Error('invalid selection');
|
||||
}
|
||||
|
||||
return choices[result.label] || false;
|
||||
}
|
||||
|
||||
@@ -36,22 +36,22 @@ export class Config {
|
||||
/**
|
||||
* Returns the config value of required python packages
|
||||
*/
|
||||
public get requiredPythonPackages(): PackageConfigModel[] {
|
||||
return this._configValues.requiredPythonPackages;
|
||||
public get requiredSqlPythonPackages(): PackageConfigModel[] {
|
||||
return this._configValues.sqlPackageManagement.requiredPythonPackages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the config value of required r packages
|
||||
*/
|
||||
public get requiredRPackages(): PackageConfigModel[] {
|
||||
return this._configValues.requiredRPackages;
|
||||
public get requiredSqlRPackages(): PackageConfigModel[] {
|
||||
return this._configValues.sqlPackageManagement.requiredRPackages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns r packages repository
|
||||
*/
|
||||
public get rPackagesRepository(): string {
|
||||
return this._configValues.rPackagesRepository;
|
||||
return this._configValues.sqlPackageManagement.rPackagesRepository;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -79,28 +79,28 @@ export class Config {
|
||||
* Returns registered models table name
|
||||
*/
|
||||
public get registeredModelTableName(): string {
|
||||
return this._configValues.registeredModelsTableName;
|
||||
return this._configValues.modelManagement.registeredModelsTableName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns registered models table name
|
||||
*/
|
||||
public get registeredModelDatabaseName(): string {
|
||||
return this._configValues.registeredModelsDatabaseName;
|
||||
return this._configValues.modelManagement.registeredModelsDatabaseName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns Azure ML API
|
||||
*/
|
||||
public get amlModelManagementUrl(): string {
|
||||
return this._configValues.amlModelManagementUrl;
|
||||
return this._configValues.modelManagement.amlModelManagementUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns Azure ML API
|
||||
*/
|
||||
public get amlExperienceUrl(): string {
|
||||
return this._configValues.amlExperienceUrl;
|
||||
return this._configValues.modelManagement.amlExperienceUrl;
|
||||
}
|
||||
|
||||
|
||||
@@ -108,7 +108,14 @@ export class Config {
|
||||
* Returns Azure ML API Version
|
||||
*/
|
||||
public get amlApiVersion(): string {
|
||||
return this._configValues.amlApiVersion;
|
||||
return this._configValues.modelManagement.amlApiVersion;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns model management python packages
|
||||
*/
|
||||
public get modelsRequiredPythonPackages(): PackageConfigModel[] {
|
||||
return this._configValues.modelManagement.requiredPythonPackages;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -103,7 +103,7 @@ export default class MainController implements vscode.Disposable {
|
||||
let mssqlService = await this.getLanguageExtensionService();
|
||||
let languagesModel = new LanguageService(this._apiWrapper, mssqlService);
|
||||
let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel);
|
||||
let modelImporter = new ModelImporter(this._outputChannel, this._apiWrapper, this._processService, this._config);
|
||||
let modelImporter = new ModelImporter(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
|
||||
|
||||
// Model Management
|
||||
//
|
||||
|
||||
@@ -21,6 +21,7 @@ import { HttpClient } from '../common/httpClient';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as utils from '../common/utils';
|
||||
|
||||
/**
|
||||
* Azure Model Service
|
||||
@@ -109,7 +110,7 @@ export class AzureModelRegistryService {
|
||||
try {
|
||||
const downloadUrls = await this.getAssetArtifactsDownloadLinks(account, subscription, resourceGroup, workspace, model, tenant);
|
||||
if (downloadUrls && downloadUrls.length > 0) {
|
||||
downloadedFilePath = await this.downloadArtifact(downloadUrls[0]);
|
||||
downloadedFilePath = await this.execDownloadArtifactTask(downloadUrls[0]);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
@@ -122,29 +123,15 @@ export class AzureModelRegistryService {
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
*/
|
||||
public async downloadArtifact(downloadUrl: string): Promise<string> {
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
let msgTaskName = constants.downloadModelMsgTaskName;
|
||||
this._apiWrapper.startBackgroundOperation({
|
||||
displayName: msgTaskName,
|
||||
description: msgTaskName,
|
||||
isCancelable: false,
|
||||
operation: async op => {
|
||||
let tempFilePath: string = '';
|
||||
try {
|
||||
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await this._httpClient.download(downloadUrl, tempFilePath, op, this._outputChannel);
|
||||
public async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
|
||||
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
|
||||
return results && results.length > 0 ? results[0] : constants.noResultError;
|
||||
}
|
||||
|
||||
op.updateStatus(azdata.TaskStatus.Succeeded);
|
||||
resolve(tempFilePath);
|
||||
} catch (error) {
|
||||
let errorMsg = constants.installDependenciesError(error ? error.message : '');
|
||||
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
|
||||
reject(errorMsg);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
private async downloadArtifact(downloadUrl: string): Promise<string> {
|
||||
let tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await this._httpClient.download(downloadUrl, tempFilePath, this._outputChannel);
|
||||
return tempFilePath;
|
||||
}
|
||||
|
||||
private async fetchWorkspaces(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
|
||||
|
||||
@@ -49,8 +49,12 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
||||
* An interface representing registered model
|
||||
*/
|
||||
export interface RegisteredModel {
|
||||
id: number,
|
||||
name: string
|
||||
id?: number,
|
||||
artifactName?: string,
|
||||
title?: string,
|
||||
created?: string,
|
||||
version?: string
|
||||
description?: string
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -9,6 +9,9 @@ import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as utils from '../common/utils';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
/**
|
||||
* Service to import model to database
|
||||
@@ -18,13 +21,22 @@ export class ModelImporter {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config) {
|
||||
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
|
||||
}
|
||||
|
||||
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
await this.installDependencies();
|
||||
await this.executeScripts(connection, modelFolderPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for model importer
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
|
||||
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
|
||||
}
|
||||
|
||||
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
|
||||
const parts = modelFolderPath.split('\\');
|
||||
@@ -36,7 +48,7 @@ export class ModelImporter {
|
||||
let server = connection.serverName;
|
||||
|
||||
const experimentId = `ads_ml_experiment_${UUID.generateUuid()}`;
|
||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}` : '';
|
||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
|
||||
let scripts: string[] = [
|
||||
'import mlflow.onnx',
|
||||
'import onnx',
|
||||
@@ -44,7 +56,7 @@ export class ModelImporter {
|
||||
`onx = onnx.load("${modelFolderPath}")`,
|
||||
'client = MlflowClient()',
|
||||
`exp_name = "${experimentId}"`,
|
||||
`db_uri_artifact = "mssql+pyodbc://${credential}@${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server"`,
|
||||
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
|
||||
'client.create_experiment(exp_name, artifact_location=db_uri_artifact)',
|
||||
'mlflow.set_experiment(exp_name)',
|
||||
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'
|
||||
|
||||
@@ -6,10 +6,12 @@
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import { Config } from '../configurations/config';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { RegisteredModel } from './interfaces';
|
||||
import { ModelImporter } from './modelImporter';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
/**
|
||||
* Service to registered models
|
||||
@@ -33,20 +35,57 @@ export class RegisteredModelService {
|
||||
let result = await this.runRegisteredModelsListQuery(connection);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push({
|
||||
id: +row[0].displayValue,
|
||||
name: row[1].displayValue
|
||||
});
|
||||
list.push(this.loadModelData(row));
|
||||
});
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string) {
|
||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
||||
return {
|
||||
id: +row[0].displayValue,
|
||||
artifactName: row[1].displayValue,
|
||||
title: row[2].displayValue,
|
||||
description: row[3].displayValue,
|
||||
version: row[4].displayValue,
|
||||
created: row[5].displayValue
|
||||
};
|
||||
}
|
||||
|
||||
public async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let updatedModel: RegisteredModel | undefined = undefined;
|
||||
if (connection) {
|
||||
let result = await this.runUpdateModelQuery(connection, model);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const row = result.rows[0];
|
||||
updatedModel = this.loadModelData(row);
|
||||
}
|
||||
}
|
||||
return updatedModel;
|
||||
}
|
||||
|
||||
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let currentModels = await this.getRegisteredModels();
|
||||
await this._modelImporter.registerModel(connection, filePath);
|
||||
let updatedModels = await this.getRegisteredModels();
|
||||
if (details && updatedModels.length >= currentModels.length + 1) {
|
||||
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
|
||||
const addedModel = updatedModels[updatedModels.length - 1];
|
||||
addedModel.title = details.title;
|
||||
addedModel.description = details.description;
|
||||
addedModel.version = details.version;
|
||||
const updatedModel = await this.updateModel(addedModel);
|
||||
if (!updatedModel) {
|
||||
throw Error(constants.updateModelFailedError);
|
||||
}
|
||||
|
||||
} else {
|
||||
throw Error(constants.importModelFailedError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,22 +95,91 @@ export class RegisteredModelService {
|
||||
|
||||
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
|
||||
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private registeredModelsQuery(databaseName: string, tableName: string) {
|
||||
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
|
||||
return `
|
||||
IF (EXISTS (SELECT name
|
||||
FROM master.dbo.sysdatabases
|
||||
WHERE ('[' + name + ']' = '${databaseName}'
|
||||
OR name = '${databaseName}')))
|
||||
${this.configureTable(databaseName, tableName)}
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created
|
||||
FROM [${escapedDbName}].dbo.[${escapedTableName}]
|
||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||
Order by artifact_id
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the table and adds extra columns (name, description, version) if doesn't already exist.
|
||||
* Note: this code is temporary and will be removed weh the table supports the required schema
|
||||
* @param databaseName
|
||||
* @param tableName
|
||||
*/
|
||||
private configureTable(databaseName: string, tableName: string): string {
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
IF EXISTS
|
||||
( SELECT [name]
|
||||
FROM sys.tables
|
||||
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
)
|
||||
BEGIN
|
||||
SELECT artifact_id, artifact_name, group_path, artifact_initial_size from ${databaseName}.${tableName}
|
||||
WHERE artifact_name like '%.onnx'
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
|
||||
BEGIN
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||
END
|
||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
|
||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
|
||||
END
|
||||
`;
|
||||
}
|
||||
|
||||
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
|
||||
|
||||
if (!currentDatabaseName) {
|
||||
currentDatabaseName = 'master';
|
||||
}
|
||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
UPDATE ${escapedTableName}
|
||||
SET
|
||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||
WHERE artifact_id = ${model.id};
|
||||
|
||||
USE [${escapedCurrentDbName}]
|
||||
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
|
||||
WHERE artifact_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,8 +20,6 @@ import { PackageConfigModel } from '../configurations/packageConfigModel';
|
||||
|
||||
export class PackageManager {
|
||||
|
||||
private _pythonExecutable: string = '';
|
||||
private _rExecutable: string = '';
|
||||
private _sqlPythonPackagePackageManager: SqlPythonPackageManageProvider;
|
||||
private _sqlRPackageManager: SqlRPackageManageProvider;
|
||||
public dependenciesInstalled: boolean = false;
|
||||
@@ -45,10 +43,15 @@ export class PackageManager {
|
||||
* Initializes the instance and resister SQL package manager with manage package dialog
|
||||
*/
|
||||
public init(): void {
|
||||
this._pythonExecutable = this._config.pythonExecutable;
|
||||
this._rExecutable = this._config.rExecutable;
|
||||
}
|
||||
|
||||
private get pythonExecutable(): string {
|
||||
return this._config.pythonExecutable;
|
||||
}
|
||||
|
||||
private get _rExecutable(): string {
|
||||
return this._config.rExecutable;
|
||||
}
|
||||
/**
|
||||
* Returns packageManageProviders
|
||||
*/
|
||||
@@ -70,9 +73,9 @@ export class PackageManager {
|
||||
let isPythonInstalled = await this._queryRunner.isPythonInstalled(connection);
|
||||
let isRInstalled = await this._queryRunner.isRInstalled(connection);
|
||||
let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined;
|
||||
if (connection && isPythonInstalled) {
|
||||
if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) {
|
||||
defaultProvider = this._sqlPythonPackagePackageManager;
|
||||
} else if (connection && isRInstalled) {
|
||||
} else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) {
|
||||
defaultProvider = this._sqlRPackageManager;
|
||||
}
|
||||
if (connection && defaultProvider) {
|
||||
@@ -104,34 +107,12 @@ export class PackageManager {
|
||||
* Installs dependencies for the extension
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
let msgTaskName = constants.installDependenciesMsgTaskName;
|
||||
this._apiWrapper.startBackgroundOperation({
|
||||
displayName: msgTaskName,
|
||||
description: msgTaskName,
|
||||
isCancelable: false,
|
||||
operation: async op => {
|
||||
try {
|
||||
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
|
||||
|
||||
// Install required packages
|
||||
//
|
||||
await Promise.all([
|
||||
this.installRequiredPythonPackages(),
|
||||
this.installRequiredRPackages(op)]);
|
||||
op.updateStatus(azdata.TaskStatus.Succeeded);
|
||||
resolve();
|
||||
} catch (error) {
|
||||
let errorMsg = constants.installDependenciesError(error ? error.message : '');
|
||||
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
|
||||
reject(errorMsg);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
|
||||
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
|
||||
this.installRequiredRPackages()], true);
|
||||
}
|
||||
|
||||
private async installRequiredRPackages(startBackgroundOperation: azdata.BackgroundOperation): Promise<void> {
|
||||
private async installRequiredRPackages(): Promise<void> {
|
||||
if (!this._config.rEnabled) {
|
||||
return;
|
||||
}
|
||||
@@ -139,22 +120,27 @@ export class PackageManager {
|
||||
throw new Error(constants.rConfigError);
|
||||
}
|
||||
|
||||
await Promise.all(this._config.requiredRPackages.map(x => this.installRPackage(x, startBackgroundOperation)));
|
||||
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
|
||||
await Promise.all(this._config.requiredSqlPythonPackages.map(x => this.installRPackage(x)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs required python packages
|
||||
*/
|
||||
private async installRequiredPythonPackages(): Promise<void> {
|
||||
public async installRequiredPythonPackages(requiredPackages: PackageConfigModel[]): Promise<void> {
|
||||
if (!this._config.pythonEnabled) {
|
||||
return;
|
||||
}
|
||||
if (!this._pythonExecutable) {
|
||||
if (!this.pythonExecutable) {
|
||||
throw new Error(constants.pythonConfigError);
|
||||
}
|
||||
if (!requiredPackages || requiredPackages.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
let installedPackages = await this.getInstalledPipPackages();
|
||||
let fileContent = '';
|
||||
this._config.requiredPythonPackages.forEach(packageDetails => {
|
||||
requiredPackages.forEach(packageDetails => {
|
||||
let hasVersion = ('version' in packageDetails) && !isNullOrUndefined(packageDetails['version']) && packageDetails['version'].length > 0;
|
||||
if (!installedPackages.find(x => x.name === packageDetails['name'] && (!hasVersion || packageDetails['version'] === x.version))) {
|
||||
let packageNameDetail = hasVersion ? `${packageDetails.name}==${packageDetails.version}` : `${packageDetails.name}`;
|
||||
@@ -163,11 +149,17 @@ export class PackageManager {
|
||||
});
|
||||
|
||||
if (fileContent) {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackages);
|
||||
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
|
||||
return await this.installPipPackage(tempFilePath);
|
||||
});
|
||||
this._outputChannel.appendLine(result);
|
||||
let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages(fileContent), this._apiWrapper);
|
||||
if (confirmed) {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackages);
|
||||
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
|
||||
return await this.installPipPackage(tempFilePath);
|
||||
});
|
||||
this._outputChannel.appendLine(result);
|
||||
|
||||
} else {
|
||||
throw Error(constants.requiredPackagesNotInstalled);
|
||||
}
|
||||
} else {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackagesAlreadyInstalled);
|
||||
}
|
||||
@@ -175,7 +167,7 @@ export class PackageManager {
|
||||
|
||||
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
try {
|
||||
let cmd = `"${this._pythonExecutable}" -m pip list --format=json`;
|
||||
let cmd = `"${this.pythonExecutable}" -m pip list --format=json`;
|
||||
let packagesInfo = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
|
||||
if (packagesInfo) {
|
||||
@@ -194,18 +186,18 @@ export class PackageManager {
|
||||
}
|
||||
|
||||
private async installPipPackage(requirementFilePath: string): Promise<string> {
|
||||
let cmd = `"${this._pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
|
||||
let cmd = `"${this.pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
|
||||
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
}
|
||||
|
||||
private async installRPackage(model: PackageConfigModel, startBackgroundOperation: azdata.BackgroundOperation): Promise<string> {
|
||||
private async installRPackage(model: PackageConfigModel): Promise<string> {
|
||||
let output = '';
|
||||
let cmd = '';
|
||||
if (model.downloadUrl) {
|
||||
const packageFile = utils.getPackageFilePath(this._rootFolder, model.fileName || model.name);
|
||||
const packageExist = await utils.exists(packageFile);
|
||||
if (!packageExist) {
|
||||
await this._httpClient.download(model.downloadUrl, packageFile, startBackgroundOperation, this._outputChannel);
|
||||
await this._httpClient.download(model.downloadUrl, packageFile, this._outputChannel);
|
||||
}
|
||||
cmd = `"${this._rExecutable}" CMD INSTALL ${packageFile}`;
|
||||
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
|
||||
@@ -142,9 +142,6 @@ describe('Main Controller', () => {
|
||||
let controller = createController(testContext);
|
||||
await controller.activate();
|
||||
|
||||
should.deepEqual(controller.config.requiredPythonPackages, [
|
||||
{ name: 'pymssql', version: '2.1.4' },
|
||||
{ name: 'sqlmlutils', version: '' }
|
||||
]);
|
||||
should.notEqual(controller.config.requiredSqlPythonPackages.find(x => x.name ==='sqlmlutils'), undefined);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -81,7 +81,7 @@ describe('Package Manager', () => {
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
|
||||
});
|
||||
|
||||
@@ -110,24 +110,27 @@ describe('Package Manager', () => {
|
||||
|
||||
it('installDependencies Should install packages that are not already installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesInstalled = false;
|
||||
//let packagesInstalled = false;
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
|
||||
if (command.indexOf('pip install') > 0) {
|
||||
packagesInstalled = true;
|
||||
//packagesInstalled = true;
|
||||
}
|
||||
return Promise.resolve(installedPackages);
|
||||
});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
should.equal(packagesInstalled, true);
|
||||
//should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
//should.equal(packagesInstalled, true);
|
||||
});
|
||||
|
||||
it('installDependencies Should install packages if list packages fails', async function (): Promise<void> {
|
||||
@@ -136,6 +139,9 @@ describe('Package Manager', () => {
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command,) => {
|
||||
if (command.indexOf('pip list') > 0) {
|
||||
@@ -163,7 +169,7 @@ describe('Package Manager', () => {
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject());
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject());
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
|
||||
if (command.indexOf('pip list') > 0) {
|
||||
return Promise.resolve(installedPackages);
|
||||
@@ -181,15 +187,15 @@ describe('Package Manager', () => {
|
||||
});
|
||||
|
||||
function createPackageManager(testContext: TestContext): PackageManager {
|
||||
testContext.config.setup(x => x.requiredPythonPackages).returns( () => [
|
||||
testContext.config.setup(x => x.requiredSqlPythonPackages).returns( () => [
|
||||
{ name: 'pymssql', version: '2.1.4' },
|
||||
{ name: 'sqlmlutils', version: '' }
|
||||
]);
|
||||
testContext.config.setup(x => x.requiredRPackages).returns( () => [
|
||||
testContext.config.setup(x => x.requiredSqlPythonPackages).returns( () => [
|
||||
{ name: 'RODBCext', repository: 'https://cran.microsoft.com' },
|
||||
{ name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'}
|
||||
]);
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
|
||||
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
|
||||
testContext.config.setup(x => x.rEnabled).returns(() => true);
|
||||
|
||||
@@ -21,11 +21,9 @@ describe('Register Model Wizard', () => {
|
||||
|
||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
|
||||
await view.refresh();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.localModelsComponent, undefined);
|
||||
should.notEqual(view.azureModelsComponent, undefined);
|
||||
should.notEqual(view.modelResources, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
});
|
||||
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
@@ -76,7 +74,7 @@ describe('Register Model Wizard', () => {
|
||||
let localModels: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
name: 'model'
|
||||
artifactName: 'model'
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
|
||||
@@ -30,7 +30,7 @@ describe('Registered Models Dialog', () => {
|
||||
let models: RegisteredModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
name: 'model'
|
||||
artifactName: 'model'
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
|
||||
@@ -39,7 +39,7 @@ export class MainViewBase {
|
||||
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._pages) {
|
||||
await Promise.all(this._pages.map(p => p.refresh()));
|
||||
await Promise.all(this._pages.map(async (p) => await p.refresh()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,10 +8,9 @@ import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
|
||||
import { AzureModelsTable } from './azureModelsTable';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent, AzureModelResource } from '../interfaces';
|
||||
import { IDataComponent, AzureModelResource } from '../interfaces';
|
||||
|
||||
export class AzureModelsComponent extends ModelViewBase implements IPageView, IDataComponent<AzureModelResource> {
|
||||
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource> {
|
||||
|
||||
public azureModelsTable: AzureModelsTable | undefined;
|
||||
public azureFilterComponent: AzureResourceFilterComponent | undefined;
|
||||
@@ -46,15 +45,36 @@ export class AzureModelsComponent extends ModelViewBase implements IPageView, ID
|
||||
});
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.azureModelFilter,
|
||||
title: '',
|
||||
component: this.azureFilterComponent.component
|
||||
}, {
|
||||
title: constants.azureModels,
|
||||
title: '',
|
||||
component: this._loader
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.azureFilterComponent && this._loader) {
|
||||
this.azureFilterComponent.addComponents(formBuilder);
|
||||
|
||||
formBuilder.addFormItems([{
|
||||
title: '',
|
||||
component: this._loader
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.azureFilterComponent && this._loader) {
|
||||
this.azureFilterComponent.removeComponents(formBuilder);
|
||||
formBuilder.removeFormItem({
|
||||
title: '',
|
||||
component: this._loader
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
@@ -93,11 +113,4 @@ export class AzureModelsComponent extends ModelViewBase implements IPageView, ID
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the title of the page
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.azureModelsTitle;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,9 +36,22 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: [
|
||||
{ // Id
|
||||
displayName: constants.modeIld,
|
||||
ariaLabel: constants.modeIld,
|
||||
{ // Name
|
||||
displayName: constants.modelName,
|
||||
ariaLabel: constants.modelName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Created
|
||||
displayName: constants.modelCreated,
|
||||
ariaLabel: constants.modelCreated,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 100,
|
||||
@@ -49,12 +62,12 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Name
|
||||
displayName: constants.modelName,
|
||||
ariaLabel: constants.modelName,
|
||||
{ // Version
|
||||
displayName: constants.modelVersion,
|
||||
ariaLabel: constants.modelVersion,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
width: 100,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
@@ -116,7 +129,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
|
||||
selectModelButton.onDidClick(() => {
|
||||
this._selectedModelId = model.id;
|
||||
});
|
||||
return [model.id, model.name, selectModelButton];
|
||||
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
|
||||
}
|
||||
|
||||
return [];
|
||||
|
||||
@@ -15,7 +15,7 @@ import { AzureWorkspaceResource, IDataComponent } from '../interfaces';
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
const componentWidth = 200;
|
||||
const componentWidth = 300;
|
||||
export class AzureResourceFilterComponent extends ModelViewBase implements IDataComponent<AzureWorkspaceResource> {
|
||||
|
||||
private _form: azdata.FormContainer;
|
||||
@@ -77,6 +77,45 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
|
||||
}]).component();
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._accounts
|
||||
}, {
|
||||
title: constants.azureSubscription,
|
||||
component: this._subscriptions
|
||||
}, {
|
||||
title: constants.azureGroup,
|
||||
component: this._groups
|
||||
}, {
|
||||
title: constants.azureModelWorkspace,
|
||||
component: this._workspaces
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureAccount,
|
||||
component: this._accounts
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureSubscription,
|
||||
component: this._subscriptions
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureGroup,
|
||||
component: this._groups
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureModelWorkspace,
|
||||
component: this._workspaces
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the created component
|
||||
*/
|
||||
|
||||
@@ -37,7 +37,7 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
||||
this._tableComponent = this._dataTable.component;
|
||||
|
||||
let registerButton = modelBuilder.button().withProperties({
|
||||
label: constants.registerModelButton,
|
||||
label: constants.registerModelTitle,
|
||||
width: this.buttonMaxLength
|
||||
}).component();
|
||||
registerButton.onDidClick(async () => {
|
||||
|
||||
@@ -33,12 +33,12 @@ export class CurrentModelsTable extends ModelViewBase {
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: [
|
||||
{ // Id
|
||||
displayName: constants.modeIld,
|
||||
ariaLabel: constants.modeIld,
|
||||
{ // Artifact name
|
||||
displayName: constants.modelArtifactName,
|
||||
ariaLabel: constants.modelArtifactName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 100,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
@@ -59,6 +59,19 @@ export class CurrentModelsTable extends ModelViewBase {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Created
|
||||
displayName: constants.modelCreated,
|
||||
ariaLabel: constants.modelCreated,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
@@ -116,7 +129,7 @@ export class CurrentModelsTable extends ModelViewBase {
|
||||
}).component();
|
||||
editLanguageButton.onDidClick(() => {
|
||||
});
|
||||
return [model.id, model.name, editLanguageButton];
|
||||
return [model.artifactName, model.title, model.created, editLanguageButton];
|
||||
}
|
||||
|
||||
return [];
|
||||
|
||||
@@ -7,14 +7,15 @@ import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
|
||||
/**
|
||||
* View to pick local models file
|
||||
*/
|
||||
export class LocalModelsComponent extends ModelViewBase implements IPageView, IDataComponent<string> {
|
||||
export class LocalModelsComponent extends ModelViewBase implements IDataComponent<string> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _flex: azdata.FlexContainer | undefined;
|
||||
private _localPath: azdata.InputBoxComponent | undefined;
|
||||
private _localBrowse: azdata.ButtonComponent | undefined;
|
||||
|
||||
@@ -48,21 +49,40 @@ export class LocalModelsComponent extends ModelViewBase implements IPageView, ID
|
||||
}
|
||||
});
|
||||
|
||||
let flexFilePathModel = modelBuilder.flexContainer()
|
||||
this._flex = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between'
|
||||
justifyContent: 'space-between',
|
||||
width: this.componentMaxLength
|
||||
}).withItems([
|
||||
this._localPath, this._localBrowse]
|
||||
).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
component: flexFilePathModel
|
||||
component: this._flex
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
formBuilder.addFormItem({
|
||||
title: '',
|
||||
component: this._flex
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
formBuilder.removeFormItem({
|
||||
title: '',
|
||||
component: this._flex
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick local models file
|
||||
*/
|
||||
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModel> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _nameComponent: azdata.InputBoxComponent | undefined;
|
||||
private _descriptionComponent: azdata.InputBoxComponent | undefined;
|
||||
|
||||
/**
|
||||
* Creates new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._nameComponent = modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength
|
||||
}).component();
|
||||
this._descriptionComponent = modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
multiline: true,
|
||||
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength,
|
||||
hight: '50px'
|
||||
}).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.modelName,
|
||||
component: this._nameComponent
|
||||
}, {
|
||||
title: constants.modelDescription,
|
||||
component: this._descriptionComponent
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._nameComponent && this._descriptionComponent) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.modelName,
|
||||
component: this._nameComponent
|
||||
}, {
|
||||
title: constants.modelDescription,
|
||||
component: this._descriptionComponent
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._nameComponent && this._descriptionComponent) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelName,
|
||||
component: this._nameComponent
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelDescription,
|
||||
component: this._descriptionComponent
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): RegisteredModel {
|
||||
return {
|
||||
title: this._nameComponent?.value,
|
||||
description: this._descriptionComponent?.value
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { ModelDetailsComponent } from './modelDetailsComponent';
|
||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick model details
|
||||
*/
|
||||
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModel> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public modelDetails: ModelDetailsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.modelDetails = new ModelDetailsComponent(this._apiWrapper, this);
|
||||
this.modelDetails.registerComponent(modelBuilder);
|
||||
|
||||
this.modelDetails.addComponents(this._formBuilder);
|
||||
this.refresh();
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): RegisteredModel | undefined {
|
||||
return this.modelDetails?.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.modelDetailsPageTitle;
|
||||
}
|
||||
}
|
||||
@@ -52,6 +52,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
// Open view
|
||||
//
|
||||
view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
|
||||
@@ -90,7 +91,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
});
|
||||
view.on(RegisterLocalModelEventName, async (arg) => {
|
||||
let registerArgs = <RegisterLocalModelEventArgs>arg;
|
||||
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, registerArgs.filePath);
|
||||
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, registerArgs.filePath, registerArgs.details);
|
||||
view.refresh();
|
||||
});
|
||||
view.on(RegisterModelEventName, async () => {
|
||||
@@ -99,7 +100,7 @@ export class ModelManagementController extends ControllerBase {
|
||||
view.on(RegisterAzureModelEventName, async (arg) => {
|
||||
let registerArgs = <RegisterAzureModelEventArgs>arg;
|
||||
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
|
||||
});
|
||||
view.on(SourceModelSelectedEventName, () => {
|
||||
view.refresh();
|
||||
@@ -157,9 +158,9 @@ export class ModelManagementController extends ControllerBase {
|
||||
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
||||
}
|
||||
|
||||
private async registerLocalModel(service: RegisteredModelService, filePath?: string): Promise<void> {
|
||||
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise<void> {
|
||||
if (filePath) {
|
||||
await service.registerLocalModel(filePath);
|
||||
await service.registerLocalModel(filePath, details);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
|
||||
@@ -173,13 +174,15 @@ export class ModelManagementController extends ControllerBase {
|
||||
subscription: azureResource.AzureResourceSubscription | undefined,
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
workspace: Workspace | undefined,
|
||||
model: WorkspaceModel | undefined): Promise<void> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace || !model) {
|
||||
model: WorkspaceModel | undefined,
|
||||
details: RegisteredModel | undefined): Promise<void> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
|
||||
throw Error(constants.invalidAzureResourceError);
|
||||
}
|
||||
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
|
||||
if (filePath) {
|
||||
await service.registerLocalModel(filePath);
|
||||
|
||||
await service.registerLocalModel(filePath, details);
|
||||
await fs.promises.unlink(filePath);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
|
||||
import { LocalModelsComponent } from './localModelsComponent';
|
||||
import { AzureModelsComponent } from './azureModelsComponent';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ModelSourcePage extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public modelResources: ModelSourcesComponent | undefined;
|
||||
public localModelsComponent: LocalModelsComponent | undefined;
|
||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
|
||||
this.modelResources.registerComponent(modelBuilder);
|
||||
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
|
||||
this.localModelsComponent.registerComponent(modelBuilder);
|
||||
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
|
||||
this.azureModelsComponent.registerComponent(modelBuilder);
|
||||
this.modelResources.addComponents(this._formBuilder);
|
||||
this.refresh();
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ModelSourceType {
|
||||
return this.modelResources?.data || ModelSourceType.Local;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._formBuilder) {
|
||||
if (this.modelResources && this.modelResources.data === ModelSourceType.Local) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent) {
|
||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||
this.localModelsComponent.addComponents(this._formBuilder);
|
||||
await this.localModelsComponent.refresh();
|
||||
}
|
||||
|
||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent) {
|
||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||
this.azureModelsComponent.addComponents(this._formBuilder);
|
||||
await this.azureModelsComponent.refresh();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.modelSourcePageTitle;
|
||||
}
|
||||
}
|
||||
@@ -7,18 +7,19 @@ import * as azdata from 'azdata';
|
||||
import { ModelViewBase, SourceModelSelectedEventName } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
|
||||
export enum ModelSourceType {
|
||||
Local,
|
||||
Azure
|
||||
}
|
||||
/**
|
||||
* View tp pick model source
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ModelSourcesComponent extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
|
||||
export class ModelSourcesComponent extends ModelViewBase implements IDataComponent<ModelSourceType> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _flexContainer: azdata.FlexContainer | undefined;
|
||||
private _amlModel: azdata.RadioButtonComponent | undefined;
|
||||
private _localModel: azdata.RadioButtonComponent | undefined;
|
||||
private _isLocalModel: boolean = true;
|
||||
@@ -58,7 +59,8 @@ export class ModelSourcesComponent extends ModelViewBase implements IPageView, I
|
||||
this.sendRequest(SourceModelSelectedEventName);
|
||||
});
|
||||
|
||||
let flex = modelBuilder.flexContainer()
|
||||
|
||||
this._flexContainer = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'column',
|
||||
justifyContent: 'space-between'
|
||||
@@ -67,12 +69,25 @@ export class ModelSourcesComponent extends ModelViewBase implements IPageView, I
|
||||
).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.modelSourcesTitle,
|
||||
component: flex
|
||||
title: '',
|
||||
component: this._flexContainer
|
||||
}]).component();
|
||||
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flexContainer) {
|
||||
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer });
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flexContainer) {
|
||||
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
@@ -92,11 +107,4 @@ export class ModelSourcesComponent extends ModelViewBase implements IPageView, I
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.modelSourcesTitle;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,11 +15,15 @@ import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
||||
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
|
||||
}
|
||||
|
||||
export interface RegisterAzureModelEventArgs extends AzureModelResource {
|
||||
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
|
||||
details?: RegisteredModel
|
||||
}
|
||||
|
||||
export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs {
|
||||
model?: WorkspaceModel;
|
||||
}
|
||||
|
||||
export interface RegisterLocalModelEventArgs extends AzureResourceEventArgs {
|
||||
export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
|
||||
filePath?: string;
|
||||
}
|
||||
|
||||
@@ -102,9 +106,10 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
* registers local model
|
||||
* @param localFilePath local file path
|
||||
*/
|
||||
public async registerLocalModel(localFilePath: string | undefined): Promise<void> {
|
||||
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModel | undefined): Promise<void> {
|
||||
const args: RegisterLocalModelEventArgs = {
|
||||
filePath: localFilePath
|
||||
filePath: localFilePath,
|
||||
details: details
|
||||
};
|
||||
return await this.sendDataRequest(RegisterLocalModelEventName, args);
|
||||
}
|
||||
@@ -113,7 +118,10 @@ export abstract class ModelViewBase extends ViewBase {
|
||||
* registers azure model
|
||||
* @param args azure resource
|
||||
*/
|
||||
public async registerAzureModel(args: RegisterAzureModelEventArgs | undefined): Promise<void> {
|
||||
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModel | undefined): Promise<void> {
|
||||
const args: RegisterAzureModelEventArgs = Object.assign({}, resource, {
|
||||
details: details
|
||||
});
|
||||
return await this.sendDataRequest(RegisterAzureModelEventName, args);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,15 +11,16 @@ import { LocalModelsComponent } from './localModelsComponent';
|
||||
import { AzureModelsComponent } from './azureModelsComponent';
|
||||
import * as constants from '../../common/constants';
|
||||
import { WizardView } from '../wizardView';
|
||||
import { ModelSourcePage } from './modelSourcePage';
|
||||
import { ModelDetailsPage } from './modelDetailsPage';
|
||||
|
||||
/**
|
||||
* Wizard to register a model
|
||||
*/
|
||||
export class RegisterModelWizard extends ModelViewBase {
|
||||
|
||||
public modelResources: ModelSourcesComponent | undefined;
|
||||
public localModelsComponent: LocalModelsComponent | undefined;
|
||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||
public modelSourcePage: ModelSourcePage | undefined;
|
||||
public modelDetailsPage: ModelDetailsPage | undefined;
|
||||
public wizardView: WizardView | undefined;
|
||||
private _parentView: ModelViewBase | undefined;
|
||||
|
||||
@@ -35,21 +36,23 @@ export class RegisterModelWizard extends ModelViewBase {
|
||||
* Opens a dialog to manage packages used by notebooks.
|
||||
*/
|
||||
public open(): void {
|
||||
|
||||
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
|
||||
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
|
||||
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
|
||||
|
||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
|
||||
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
|
||||
this.wizardView = new WizardView(this._apiWrapper);
|
||||
|
||||
let wizard = this.wizardView.createWizard(constants.registerModelWizardTitle, [this.modelResources, this.localModelsComponent]);
|
||||
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelDetailsPage]);
|
||||
|
||||
this.mainViewPanel = wizard;
|
||||
wizard.doneButton.label = constants.azureRegisterModel;
|
||||
wizard.generateScriptButton.hidden = true;
|
||||
|
||||
wizard.displayPageTitles = true;
|
||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||
if (pageInfo.newPage === undefined) {
|
||||
wizard.cancelButton.enabled = false;
|
||||
wizard.backButton.enabled = false;
|
||||
await this.registerModel();
|
||||
wizard.cancelButton.enabled = true;
|
||||
wizard.backButton.enabled = true;
|
||||
if (this._parentView) {
|
||||
this._parentView?.refresh();
|
||||
}
|
||||
@@ -62,12 +65,24 @@ export class RegisterModelWizard extends ModelViewBase {
|
||||
wizard.open();
|
||||
}
|
||||
|
||||
public get modelResources(): ModelSourcesComponent | undefined {
|
||||
return this.modelSourcePage?.modelResources;
|
||||
}
|
||||
|
||||
public get localModelsComponent(): LocalModelsComponent | undefined {
|
||||
return this.modelSourcePage?.localModelsComponent;
|
||||
}
|
||||
|
||||
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
||||
return this.modelSourcePage?.azureModelsComponent;
|
||||
}
|
||||
|
||||
private async registerModel(): Promise<boolean> {
|
||||
try {
|
||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||
await this.registerLocalModel(this.localModelsComponent.data);
|
||||
await this.registerLocalModel(this.localModelsComponent.data, this.modelDetailsPage?.data);
|
||||
} else {
|
||||
await this.registerAzureModel(this.azureModelsComponent?.data);
|
||||
await this.registerAzureModel(this.azureModelsComponent?.data, this.modelDetailsPage?.data);
|
||||
}
|
||||
this.showInfoMessage(constants.modelRegisteredSuccessfully);
|
||||
return true;
|
||||
@@ -78,12 +93,6 @@ export class RegisterModelWizard extends ModelViewBase {
|
||||
}
|
||||
|
||||
private loadPages(): void {
|
||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||
this.wizardView?.addWizardPage(this.localModelsComponent, 1);
|
||||
|
||||
} else if (this.azureModelsComponent) {
|
||||
this.wizardView?.addWizardPage(this.azureModelsComponent, 1);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -91,6 +100,6 @@ export class RegisterModelWizard extends ModelViewBase {
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
this.loadPages();
|
||||
this.wizardView?.refresh();
|
||||
await this.wizardView?.refresh();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user