Machine Learning Extension - Model details (#9377)

* Machine Learning Services Extension - adding model details
This commit is contained in:
Leila Lali
2020-03-02 12:47:09 -08:00
committed by GitHub
parent c1f6a67829
commit b5b65117a7
30 changed files with 852 additions and 224 deletions

View File

@@ -1,19 +1,55 @@
{
"requiredPythonPackages": [
{ "name": "pymssql", "version": "2.1.4" },
{ "name": "sqlmlutils", "version": ""}
],
"requiredRPackages": [
{ "name": "RODBCext", "repository": "https://cran.microsoft.com" },
{ "name": "sqlmlutils", "fileName": "sqlmlutils_0.7.1.zip", "downloadUrl": "https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true"}
],
"rPackagesRepository": "https://cran.r-project.org",
"registeredModelsDatabaseName": "MlFlowDB",
"registeredModelsTableName": "dbo.artifacts",
"amlModelManagementUrl": "modelmanagement.azureml.net",
"amlExperienceUrl": "experiments.azureml.net",
"amlApiVersion": "2018-11-19"
"sqlPackageManagement": {
"requiredPythonPackages": [
{
"name": "pymssql",
"version": "2.1.4"
},
{
"name": "sqlmlutils",
"version": ""
}
],
"requiredRPackages": [
{
"name": "RODBCext",
"repository": "https://cran.microsoft.com"
},
{
"name": "sqlmlutils",
"fileName": "sqlmlutils_0.7.1.zip",
"downloadUrl": "https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true"
}
],
"rPackagesRepository": "https://cran.r-project.org"
},
"modelManagement": {
"registeredModelsDatabaseName": "MlFlowDB",
"registeredModelsTableName": "artifacts",
"amlModelManagementUrl": "modelmanagement.azureml.net",
"amlExperienceUrl": "experiments.azureml.net",
"amlApiVersion": "2018-11-19",
"requiredPythonPackages": [
{
"name": "onnx",
"version": ""
},
{
"name": "onnxruntime",
"version": ""
},
{
"name": "mlflow",
"version": ""
},
{
"name": "pyodbc",
"version": ""
},
{
"name": "mlflow-dbstore",
"version": ""
}
]
}
}

View File

@@ -101,4 +101,8 @@ export class ApiWrapper {
public getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{ [key: string]: any }> {
return azdata.accounts.getSecurityToken(account, resource);
}
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
return vscode.window.showQuickPick(items, options, token);
}
}

View File

@@ -42,9 +42,17 @@ export const rPathConfigKey = 'rPath';
// Localized texts
//
export const msgYes = localize('msgYes', "Yes");
export const msgNo = localize('msgNo', "No");
export const managePackageCommandError = localize('mls.managePackages.error', "Either no connection is available or the server does not have external script enabled.");
export function installDependenciesError(err: string): string { return localize('mls.installDependencies.error', "Failed to install dependencies. Error: {0}", err); }
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
export const installDependenciesMsgTaskName = localize('mls.installDependencies.msgTaskName', "Installing Machine Learning extension dependencies");
export const noResultError = localize('mls.noResultError', "No Result returned");
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
export function confirmInstallPythonPackages(packages: string): string {
return localize('mls.installDependencies.confirmInstallPythonPackages'
, "The following Python packages are required to install: {0}. Are you sure you want to install?", packages);
}
export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ...");
export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
@@ -101,23 +109,27 @@ export const extLangSelectedPath = localize('extLang.selectedPath', "Selected Pa
export const extLangInstallFailedError = localize('extLang.installFailedError', "Failed to install language");
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
export const modeIld = localize('models.id', "Id");
export const modelArtifactName = localize('models.artifactName', "Artifact Name");
export const modelName = localize('models.name', "Name");
export const modelSize = localize('models.size', "Size");
export const modelDescription = localize('models.description', "Description");
export const modelCreated = localize('models.created', "Date Created");
export const modelVersion = localize('models.version', "Version");
export const browseModels = localize('models.browseButton', "...");
export const azureAccount = localize('models.azureAccount', "Account");
export const azureSubscription = localize('models.azureSubscription', "Subscription");
export const azureGroup = localize('models.azureGroup', "Resource Group");
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Workspace");
export const azureAccount = localize('models.azureAccount', "Azure account");
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
export const azureGroup = localize('models.azureGroup', "Azure resource group");
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
export const azureModelFilter = localize('models.azureModelFilter', "Filter");
export const azureModels = localize('models.azureModels', "Models");
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Ender model source details");
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
export const azureRegisterModel = localize('models.azureRegisterModel', "Register");
export const registerModelWizardTitle = localize('models.RegisterWizard', "Register");
export const registerModelButton = localize('models.RegisterModelButton', "Register model");
export const registerModelTitle = localize('models.RegisterWizard', "Register model");
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
export const localModelSource = localize('models.localModelSource', "Upload file");
@@ -125,6 +137,8 @@ export const azureModelSource = localize('models.azureModelSource', "Import from
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");

View File

@@ -3,7 +3,6 @@
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as fs from 'fs';
import * as request from 'request';
@@ -36,7 +35,7 @@ export class HttpClient {
});
}
public download(downloadUrl: string, targetPath: string, backgroundOperation: azdata.BackgroundOperation, outputChannel: vscode.OutputChannel): Promise<void> {
public download(downloadUrl: string, targetPath: string, outputChannel: vscode.OutputChannel): Promise<void> {
return new Promise((resolve, reject) => {
let totalMegaBytes: number | undefined = undefined;
@@ -44,12 +43,12 @@ export class HttpClient {
let printThreshold = 0.1;
let downloadRequest = request.get(downloadUrl, { timeout: DownloadTimeout })
.on('error', downloadError => {
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, constants.downloadError);
outputChannel.appendLine(constants.downloadError);
reject(downloadError);
})
.on('response', (response) => {
if (response.statusCode !== 200) {
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, constants.downloadError);
outputChannel.appendLine(constants.downloadError);
return reject(response.statusMessage);
}
let contentLength = response.headers['content-length'];
@@ -73,7 +72,6 @@ export class HttpClient {
resolve();
})
.on('error', (downloadError) => {
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, 'Error');
reject(downloadError);
downloadRequest.abort();
});

View File

@@ -3,12 +3,14 @@
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as path from 'path';
import * as os from 'os';
import * as fs from 'fs';
import * as constants from '../common/constants';
import { promisify } from 'util';
import { ApiWrapper } from './apiWrapper';
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
let tempFilePath: string = '';
@@ -101,3 +103,76 @@ export function sortPackageVersions(versions: string[], ascending: boolean = tru
export function isWindows(): boolean {
return process.platform === 'win32';
}
/**
* Escapes all single-quotes (') by prefixing them with another single quote ('')
* ' => ''
* @param value The string to escape
*/
export function doubleEscapeSingleQuotes(value: string): string {
return value.replace(/'/g, '\'\'');
}
/**
* Escapes all single-bracket ([]) by replacing them with another bracket quote ([[]])
* ' => ''
* @param value The string to escape
*/
export function doubleEscapeSingleBrackets(value: string): string {
return value.replace(/\[/g, '[[').replace(/\]/g, ']]');
}
/**
* Installs dependencies for the extension
*/
export async function executeTasks<T>(apiWrapper: ApiWrapper, taskName: string, dependencies: PromiseLike<T>[], parallel: boolean): Promise<T[]> {
return new Promise<T[]>((resolve, reject) => {
let msgTaskName = taskName;
apiWrapper.startBackgroundOperation({
displayName: msgTaskName,
description: msgTaskName,
isCancelable: false,
operation: async op => {
try {
let result: T[] = [];
// Install required packages
//
if (parallel) {
result = await Promise.all(dependencies);
} else {
for (let index = 0; index < dependencies.length; index++) {
result.push(await dependencies[index]);
}
}
op.updateStatus(azdata.TaskStatus.Succeeded);
resolve(result);
} catch (error) {
let errorMsg = constants.taskFailedError(taskName, error ? error.message : '');
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
reject(errorMsg);
}
}
});
});
}
export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Promise<boolean> {
let choices: { [id: string]: boolean } = {};
choices[constants.msgYes] = true;
choices[constants.msgNo] = false;
let options = {
placeHolder: message
};
let result = await apiWrapper.showQuickPick(Object.keys(choices).map(c => {
return {
label: c
};
}), options);
if (result === undefined) {
throw Error('invalid selection');
}
return choices[result.label] || false;
}

View File

@@ -36,22 +36,22 @@ export class Config {
/**
* Returns the config value of required python packages
*/
public get requiredPythonPackages(): PackageConfigModel[] {
return this._configValues.requiredPythonPackages;
public get requiredSqlPythonPackages(): PackageConfigModel[] {
return this._configValues.sqlPackageManagement.requiredPythonPackages;
}
/**
* Returns the config value of required r packages
*/
public get requiredRPackages(): PackageConfigModel[] {
return this._configValues.requiredRPackages;
public get requiredSqlRPackages(): PackageConfigModel[] {
return this._configValues.sqlPackageManagement.requiredRPackages;
}
/**
* Returns r packages repository
*/
public get rPackagesRepository(): string {
return this._configValues.rPackagesRepository;
return this._configValues.sqlPackageManagement.rPackagesRepository;
}
/**
@@ -79,28 +79,28 @@ export class Config {
* Returns registered models table name
*/
public get registeredModelTableName(): string {
return this._configValues.registeredModelsTableName;
return this._configValues.modelManagement.registeredModelsTableName;
}
/**
* Returns registered models table name
*/
public get registeredModelDatabaseName(): string {
return this._configValues.registeredModelsDatabaseName;
return this._configValues.modelManagement.registeredModelsDatabaseName;
}
/**
* Returns Azure ML API
*/
public get amlModelManagementUrl(): string {
return this._configValues.amlModelManagementUrl;
return this._configValues.modelManagement.amlModelManagementUrl;
}
/**
* Returns Azure ML API
*/
public get amlExperienceUrl(): string {
return this._configValues.amlExperienceUrl;
return this._configValues.modelManagement.amlExperienceUrl;
}
@@ -108,7 +108,14 @@ export class Config {
* Returns Azure ML API Version
*/
public get amlApiVersion(): string {
return this._configValues.amlApiVersion;
return this._configValues.modelManagement.amlApiVersion;
}
/**
* Returns model management python packages
*/
public get modelsRequiredPythonPackages(): PackageConfigModel[] {
return this._configValues.modelManagement.requiredPythonPackages;
}
/**

View File

@@ -103,7 +103,7 @@ export default class MainController implements vscode.Disposable {
let mssqlService = await this.getLanguageExtensionService();
let languagesModel = new LanguageService(this._apiWrapper, mssqlService);
let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel);
let modelImporter = new ModelImporter(this._outputChannel, this._apiWrapper, this._processService, this._config);
let modelImporter = new ModelImporter(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
// Model Management
//

View File

@@ -21,6 +21,7 @@ import { HttpClient } from '../common/httpClient';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as path from 'path';
import * as os from 'os';
import * as utils from '../common/utils';
/**
* Azure Model Service
@@ -109,7 +110,7 @@ export class AzureModelRegistryService {
try {
const downloadUrls = await this.getAssetArtifactsDownloadLinks(account, subscription, resourceGroup, workspace, model, tenant);
if (downloadUrls && downloadUrls.length > 0) {
downloadedFilePath = await this.downloadArtifact(downloadUrls[0]);
downloadedFilePath = await this.execDownloadArtifactTask(downloadUrls[0]);
}
} catch (error) {
@@ -122,29 +123,15 @@ export class AzureModelRegistryService {
/**
* Installs dependencies for the extension
*/
public async downloadArtifact(downloadUrl: string): Promise<string> {
return new Promise<string>((resolve, reject) => {
let msgTaskName = constants.downloadModelMsgTaskName;
this._apiWrapper.startBackgroundOperation({
displayName: msgTaskName,
description: msgTaskName,
isCancelable: false,
operation: async op => {
let tempFilePath: string = '';
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await this._httpClient.download(downloadUrl, tempFilePath, op, this._outputChannel);
public async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
return results && results.length > 0 ? results[0] : constants.noResultError;
}
op.updateStatus(azdata.TaskStatus.Succeeded);
resolve(tempFilePath);
} catch (error) {
let errorMsg = constants.installDependenciesError(error ? error.message : '');
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
reject(errorMsg);
}
}
});
});
private async downloadArtifact(downloadUrl: string): Promise<string> {
let tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await this._httpClient.download(downloadUrl, tempFilePath, this._outputChannel);
return tempFilePath;
}
private async fetchWorkspaces(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {

View File

@@ -49,8 +49,12 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
* An interface representing registered model
*/
export interface RegisteredModel {
id: number,
name: string
id?: number,
artifactName?: string,
title?: string,
created?: string,
version?: string
description?: string
}
/**

View File

@@ -9,6 +9,9 @@ import { ApiWrapper } from '../common/apiWrapper';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as utils from '../common/utils';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
/**
* Service to import model to database
@@ -18,13 +21,22 @@ export class ModelImporter {
/**
*
*/
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config) {
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
}
public async registerModel(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
await this.installDependencies();
await this.executeScripts(connection, modelFolderPath);
}
/**
* Installs dependencies for model importer
*/
public async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
}
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
const parts = modelFolderPath.split('\\');
@@ -36,7 +48,7 @@ export class ModelImporter {
let server = connection.serverName;
const experimentId = `ads_ml_experiment_${UUID.generateUuid()}`;
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}` : '';
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
let scripts: string[] = [
'import mlflow.onnx',
'import onnx',
@@ -44,7 +56,7 @@ export class ModelImporter {
`onx = onnx.load("${modelFolderPath}")`,
'client = MlflowClient()',
`exp_name = "${experimentId}"`,
`db_uri_artifact = "mssql+pyodbc://${credential}@${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server"`,
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
'client.create_experiment(exp_name, artifact_location=db_uri_artifact)',
'mlflow.set_experiment(exp_name)',
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'

View File

@@ -6,10 +6,12 @@
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
import { RegisteredModel } from './interfaces';
import { ModelImporter } from './modelImporter';
import * as constants from '../common/constants';
/**
* Service to registered models
@@ -33,20 +35,57 @@ export class RegisteredModelService {
let result = await this.runRegisteredModelsListQuery(connection);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push({
id: +row[0].displayValue,
name: row[1].displayValue
});
list.push(this.loadModelData(row));
});
}
}
return list;
}
public async registerLocalModel(filePath: string) {
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
return {
id: +row[0].displayValue,
artifactName: row[1].displayValue,
title: row[2].displayValue,
description: row[3].displayValue,
version: row[4].displayValue,
created: row[5].displayValue
};
}
public async updateModel(model: RegisteredModel): Promise<RegisteredModel | undefined> {
let connection = await this.getCurrentConnection();
let updatedModel: RegisteredModel | undefined = undefined;
if (connection) {
let result = await this.runUpdateModelQuery(connection, model);
if (result && result.rows && result.rows.length > 0) {
const row = result.rows[0];
updatedModel = this.loadModelData(row);
}
}
return updatedModel;
}
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
let connection = await this.getCurrentConnection();
if (connection) {
let currentModels = await this.getRegisteredModels();
await this._modelImporter.registerModel(connection, filePath);
let updatedModels = await this.getRegisteredModels();
if (details && updatedModels.length >= currentModels.length + 1) {
updatedModels.sort((a, b) => a.id && b.id ? a.id - b.id : 0);
const addedModel = updatedModels[updatedModels.length - 1];
addedModel.title = details.title;
addedModel.description = details.description;
addedModel.version = details.version;
const updatedModel = await this.updateModel(addedModel);
if (!updatedModel) {
throw Error(constants.updateModelFailedError);
}
} else {
throw Error(constants.importModelFailedError);
}
}
}
@@ -56,22 +95,91 @@ export class RegisteredModelService {
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
} catch {
return undefined;
}
}
private registeredModelsQuery(databaseName: string, tableName: string) {
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
} catch {
return undefined;
}
}
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
return `
IF (EXISTS (SELECT name
FROM master.dbo.sysdatabases
WHERE ('[' + name + ']' = '${databaseName}'
OR name = '${databaseName}')))
${this.configureTable(databaseName, tableName)}
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created
FROM [${escapedDbName}].dbo.[${escapedTableName}]
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
Order by artifact_id
`;
}
/**
* Update the table and adds extra columns (name, description, version) if doesn't already exist.
* Note: this code is temporary and will be removed weh the table supports the required schema
* @param databaseName
* @param tableName
*/
private configureTable(databaseName: string, tableName: string): string {
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
return `
USE [${escapedDbName}]
IF EXISTS
( SELECT [name]
FROM sys.tables
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
)
BEGIN
SELECT artifact_id, artifact_name, group_path, artifact_initial_size from ${databaseName}.${tableName}
WHERE artifact_name like '%.onnx'
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
BEGIN
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
END
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
END
`;
}
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
if (!currentDatabaseName) {
currentDatabaseName = 'master';
}
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
return `
USE [${escapedDbName}]
UPDATE ${escapedTableName}
SET
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
WHERE artifact_id = ${model.id};
USE [${escapedCurrentDbName}]
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
WHERE artifact_id = ${model.id};
`;
}
}

View File

@@ -20,8 +20,6 @@ import { PackageConfigModel } from '../configurations/packageConfigModel';
export class PackageManager {
private _pythonExecutable: string = '';
private _rExecutable: string = '';
private _sqlPythonPackagePackageManager: SqlPythonPackageManageProvider;
private _sqlRPackageManager: SqlRPackageManageProvider;
public dependenciesInstalled: boolean = false;
@@ -45,10 +43,15 @@ export class PackageManager {
* Initializes the instance and resister SQL package manager with manage package dialog
*/
public init(): void {
this._pythonExecutable = this._config.pythonExecutable;
this._rExecutable = this._config.rExecutable;
}
private get pythonExecutable(): string {
return this._config.pythonExecutable;
}
private get _rExecutable(): string {
return this._config.rExecutable;
}
/**
* Returns packageManageProviders
*/
@@ -70,9 +73,9 @@ export class PackageManager {
let isPythonInstalled = await this._queryRunner.isPythonInstalled(connection);
let isRInstalled = await this._queryRunner.isRInstalled(connection);
let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined;
if (connection && isPythonInstalled) {
if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) {
defaultProvider = this._sqlPythonPackagePackageManager;
} else if (connection && isRInstalled) {
} else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) {
defaultProvider = this._sqlRPackageManager;
}
if (connection && defaultProvider) {
@@ -104,34 +107,12 @@ export class PackageManager {
* Installs dependencies for the extension
*/
public async installDependencies(): Promise<void> {
return new Promise<void>((resolve, reject) => {
let msgTaskName = constants.installDependenciesMsgTaskName;
this._apiWrapper.startBackgroundOperation({
displayName: msgTaskName,
description: msgTaskName,
isCancelable: false,
operation: async op => {
try {
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
// Install required packages
//
await Promise.all([
this.installRequiredPythonPackages(),
this.installRequiredRPackages(op)]);
op.updateStatus(azdata.TaskStatus.Succeeded);
resolve();
} catch (error) {
let errorMsg = constants.installDependenciesError(error ? error.message : '');
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
reject(errorMsg);
}
}
});
});
await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
this.installRequiredRPackages()], true);
}
private async installRequiredRPackages(startBackgroundOperation: azdata.BackgroundOperation): Promise<void> {
private async installRequiredRPackages(): Promise<void> {
if (!this._config.rEnabled) {
return;
}
@@ -139,22 +120,27 @@ export class PackageManager {
throw new Error(constants.rConfigError);
}
await Promise.all(this._config.requiredRPackages.map(x => this.installRPackage(x, startBackgroundOperation)));
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
await Promise.all(this._config.requiredSqlPythonPackages.map(x => this.installRPackage(x)));
}
/**
* Installs required python packages
*/
private async installRequiredPythonPackages(): Promise<void> {
public async installRequiredPythonPackages(requiredPackages: PackageConfigModel[]): Promise<void> {
if (!this._config.pythonEnabled) {
return;
}
if (!this._pythonExecutable) {
if (!this.pythonExecutable) {
throw new Error(constants.pythonConfigError);
}
if (!requiredPackages || requiredPackages.length === 0) {
return;
}
let installedPackages = await this.getInstalledPipPackages();
let fileContent = '';
this._config.requiredPythonPackages.forEach(packageDetails => {
requiredPackages.forEach(packageDetails => {
let hasVersion = ('version' in packageDetails) && !isNullOrUndefined(packageDetails['version']) && packageDetails['version'].length > 0;
if (!installedPackages.find(x => x.name === packageDetails['name'] && (!hasVersion || packageDetails['version'] === x.version))) {
let packageNameDetail = hasVersion ? `${packageDetails.name}==${packageDetails.version}` : `${packageDetails.name}`;
@@ -163,11 +149,17 @@ export class PackageManager {
});
if (fileContent) {
this._outputChannel.appendLine(constants.installDependenciesPackages);
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
return await this.installPipPackage(tempFilePath);
});
this._outputChannel.appendLine(result);
let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages(fileContent), this._apiWrapper);
if (confirmed) {
this._outputChannel.appendLine(constants.installDependenciesPackages);
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
return await this.installPipPackage(tempFilePath);
});
this._outputChannel.appendLine(result);
} else {
throw Error(constants.requiredPackagesNotInstalled);
}
} else {
this._outputChannel.appendLine(constants.installDependenciesPackagesAlreadyInstalled);
}
@@ -175,7 +167,7 @@ export class PackageManager {
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
try {
let cmd = `"${this._pythonExecutable}" -m pip list --format=json`;
let cmd = `"${this.pythonExecutable}" -m pip list --format=json`;
let packagesInfo = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
if (packagesInfo) {
@@ -194,18 +186,18 @@ export class PackageManager {
}
private async installPipPackage(requirementFilePath: string): Promise<string> {
let cmd = `"${this._pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
let cmd = `"${this.pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
}
private async installRPackage(model: PackageConfigModel, startBackgroundOperation: azdata.BackgroundOperation): Promise<string> {
private async installRPackage(model: PackageConfigModel): Promise<string> {
let output = '';
let cmd = '';
if (model.downloadUrl) {
const packageFile = utils.getPackageFilePath(this._rootFolder, model.fileName || model.name);
const packageExist = await utils.exists(packageFile);
if (!packageExist) {
await this._httpClient.download(model.downloadUrl, packageFile, startBackgroundOperation, this._outputChannel);
await this._httpClient.download(model.downloadUrl, packageFile, this._outputChannel);
}
cmd = `"${this._rExecutable}" CMD INSTALL ${packageFile}`;
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);

View File

@@ -142,9 +142,6 @@ describe('Main Controller', () => {
let controller = createController(testContext);
await controller.activate();
should.deepEqual(controller.config.requiredPythonPackages, [
{ name: 'pymssql', version: '2.1.4' },
{ name: 'sqlmlutils', version: '' }
]);
should.notEqual(controller.config.requiredSqlPythonPackages.find(x => x.name ==='sqlmlutils'), undefined);
});
});

View File

@@ -81,7 +81,7 @@ describe('Package Manager', () => {
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
@@ -110,24 +110,27 @@ describe('Package Manager', () => {
it('installDependencies Should install packages that are not already installed', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
//let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"}
]`;
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
//packagesInstalled = true;
}
return Promise.resolve(installedPackages);
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, true);
//should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
//should.equal(packagesInstalled, true);
});
it('installDependencies Should install packages if list packages fails', async function (): Promise<void> {
@@ -136,6 +139,9 @@ describe('Package Manager', () => {
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command,) => {
if (command.indexOf('pip list') > 0) {
@@ -163,7 +169,7 @@ describe('Package Manager', () => {
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject());
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject());
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip list') > 0) {
return Promise.resolve(installedPackages);
@@ -181,15 +187,15 @@ describe('Package Manager', () => {
});
function createPackageManager(testContext: TestContext): PackageManager {
testContext.config.setup(x => x.requiredPythonPackages).returns( () => [
testContext.config.setup(x => x.requiredSqlPythonPackages).returns( () => [
{ name: 'pymssql', version: '2.1.4' },
{ name: 'sqlmlutils', version: '' }
]);
testContext.config.setup(x => x.requiredRPackages).returns( () => [
testContext.config.setup(x => x.requiredSqlPythonPackages).returns( () => [
{ name: 'RODBCext', repository: 'https://cran.microsoft.com' },
{ name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'}
]);
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
testContext.config.setup(x => x.rEnabled).returns(() => true);

View File

@@ -21,11 +21,9 @@ describe('Register Model Wizard', () => {
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
view.open();
await view.refresh();
should.notEqual(view.wizardView, undefined);
should.notEqual(view.localModelsComponent, undefined);
should.notEqual(view.azureModelsComponent, undefined);
should.notEqual(view.modelResources, undefined);
should.notEqual(view.modelSourcePage, undefined);
});
it('Should load data successfully ', async function (): Promise<void> {
@@ -76,7 +74,7 @@ describe('Register Model Wizard', () => {
let localModels: RegisteredModel[] = [
{
id: 1,
name: 'model'
artifactName: 'model'
}
];
view.on(ListModelsEventName, () => {

View File

@@ -30,7 +30,7 @@ describe('Registered Models Dialog', () => {
let models: RegisteredModel[] = [
{
id: 1,
name: 'model'
artifactName: 'model'
}
];
view.on(ListModelsEventName, () => {

View File

@@ -39,7 +39,7 @@ export class MainViewBase {
public async refresh(): Promise<void> {
if (this._pages) {
await Promise.all(this._pages.map(p => p.refresh()));
await Promise.all(this._pages.map(async (p) => await p.refresh()));
}
}
}

View File

@@ -8,10 +8,9 @@ import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
import { AzureModelsTable } from './azureModelsTable';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent, AzureModelResource } from '../interfaces';
import { IDataComponent, AzureModelResource } from '../interfaces';
export class AzureModelsComponent extends ModelViewBase implements IPageView, IDataComponent<AzureModelResource> {
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource> {
public azureModelsTable: AzureModelsTable | undefined;
public azureFilterComponent: AzureResourceFilterComponent | undefined;
@@ -46,15 +45,36 @@ export class AzureModelsComponent extends ModelViewBase implements IPageView, ID
});
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureModelFilter,
title: '',
component: this.azureFilterComponent.component
}, {
title: constants.azureModels,
title: '',
component: this._loader
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.addComponents(formBuilder);
formBuilder.addFormItems([{
title: '',
component: this._loader
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.removeComponents(formBuilder);
formBuilder.removeFormItem({
title: '',
component: this._loader
});
}
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
@@ -93,11 +113,4 @@ export class AzureModelsComponent extends ModelViewBase implements IPageView, ID
public async refresh(): Promise<void> {
await this.loadData();
}
/**
* Returns the title of the page
*/
public get title(): string {
return constants.azureModelsTitle;
}
}

View File

@@ -36,9 +36,22 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Id
displayName: constants.modeIld,
ariaLabel: constants.modeIld,
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
@@ -49,12 +62,12 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
{ // Version
displayName: constants.modelVersion,
ariaLabel: constants.modelVersion,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
@@ -116,7 +129,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
selectModelButton.onDidClick(() => {
this._selectedModelId = model.id;
});
return [model.id, model.name, selectModelButton];
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
}
return [];

View File

@@ -15,7 +15,7 @@ import { AzureWorkspaceResource, IDataComponent } from '../interfaces';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 200;
const componentWidth = 300;
export class AzureResourceFilterComponent extends ModelViewBase implements IDataComponent<AzureWorkspaceResource> {
private _form: azdata.FormContainer;
@@ -77,6 +77,45 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
}]).component();
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
formBuilder.addFormItems([{
title: constants.azureAccount,
component: this._accounts
}, {
title: constants.azureSubscription,
component: this._subscriptions
}, {
title: constants.azureGroup,
component: this._groups
}, {
title: constants.azureModelWorkspace,
component: this._workspaces
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._accounts
});
formBuilder.removeFormItem({
title: constants.azureSubscription,
component: this._subscriptions
});
formBuilder.removeFormItem({
title: constants.azureGroup,
component: this._groups
});
formBuilder.removeFormItem({
title: constants.azureModelWorkspace,
component: this._workspaces
});
}
}
/**
* Returns the created component
*/

View File

@@ -37,7 +37,7 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
this._tableComponent = this._dataTable.component;
let registerButton = modelBuilder.button().withProperties({
label: constants.registerModelButton,
label: constants.registerModelTitle,
width: this.buttonMaxLength
}).component();
registerButton.onDidClick(async () => {

View File

@@ -33,12 +33,12 @@ export class CurrentModelsTable extends ModelViewBase {
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Id
displayName: constants.modeIld,
ariaLabel: constants.modeIld,
{ // Artifact name
displayName: constants.modelArtifactName,
ariaLabel: constants.modelArtifactName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
@@ -59,6 +59,19 @@ export class CurrentModelsTable extends ModelViewBase {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
@@ -116,7 +129,7 @@ export class CurrentModelsTable extends ModelViewBase {
}).component();
editLanguageButton.onDidClick(() => {
});
return [model.id, model.name, editLanguageButton];
return [model.artifactName, model.title, model.created, editLanguageButton];
}
return [];

View File

@@ -7,14 +7,15 @@ import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { IDataComponent } from '../interfaces';
/**
* View to pick local models file
*/
export class LocalModelsComponent extends ModelViewBase implements IPageView, IDataComponent<string> {
export class LocalModelsComponent extends ModelViewBase implements IDataComponent<string> {
private _form: azdata.FormContainer | undefined;
private _flex: azdata.FlexContainer | undefined;
private _localPath: azdata.InputBoxComponent | undefined;
private _localBrowse: azdata.ButtonComponent | undefined;
@@ -48,21 +49,40 @@ export class LocalModelsComponent extends ModelViewBase implements IPageView, ID
}
});
let flexFilePathModel = modelBuilder.flexContainer()
this._flex = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
justifyContent: 'space-between',
width: this.componentMaxLength
}).withItems([
this._localPath, this._localBrowse]
).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: flexFilePathModel
component: this._flex
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.addFormItem({
title: '',
component: this._flex
});
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.removeFormItem({
title: '',
component: this._flex
});
}
}
/**
* Returns selected data
*/

View File

@@ -0,0 +1,103 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
import { RegisteredModel } from '../../modelManagement/interfaces';
/**
* View to pick local models file
*/
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModel> {
private _form: azdata.FormContainer | undefined;
private _nameComponent: azdata.InputBoxComponent | undefined;
private _descriptionComponent: azdata.InputBoxComponent | undefined;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._nameComponent = modelBuilder.inputBox().withProperties({
value: '',
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength
}).component();
this._descriptionComponent = modelBuilder.inputBox().withProperties({
value: '',
multiline: true,
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength,
hight: '50px'
}).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.modelName,
component: this._nameComponent
}, {
title: constants.modelDescription,
component: this._descriptionComponent
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent) {
formBuilder.addFormItems([{
title: constants.modelName,
component: this._nameComponent
}, {
title: constants.modelDescription,
component: this._descriptionComponent
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent) {
formBuilder.removeFormItem({
title: constants.modelName,
component: this._nameComponent
});
formBuilder.removeFormItem({
title: constants.modelDescription,
component: this._descriptionComponent
});
}
}
/**
* Returns selected data
*/
public get data(): RegisteredModel {
return {
title: this._nameComponent?.value,
description: this._descriptionComponent?.value
};
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { ModelDetailsComponent } from './modelDetailsComponent';
import { RegisteredModel } from '../../modelManagement/interfaces';
/**
* View to pick model details
*/
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModel> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelDetails: ModelDetailsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelDetails = new ModelDetailsComponent(this._apiWrapper, this);
this.modelDetails.registerComponent(modelBuilder);
this.modelDetails.addComponents(this._formBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): RegisteredModel | undefined {
return this.modelDetails?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelDetailsPageTitle;
}
}

View File

@@ -52,6 +52,7 @@ export class ModelManagementController extends ControllerBase {
// Open view
//
view.open();
await view.refresh();
return view;
}
@@ -90,7 +91,7 @@ export class ModelManagementController extends ControllerBase {
});
view.on(RegisterLocalModelEventName, async (arg) => {
let registerArgs = <RegisterLocalModelEventArgs>arg;
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, registerArgs.filePath);
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, registerArgs.filePath, registerArgs.details);
view.refresh();
});
view.on(RegisterModelEventName, async () => {
@@ -99,7 +100,7 @@ export class ModelManagementController extends ControllerBase {
view.on(RegisterAzureModelEventName, async (arg) => {
let registerArgs = <RegisterAzureModelEventArgs>arg;
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
});
view.on(SourceModelSelectedEventName, () => {
view.refresh();
@@ -157,9 +158,9 @@ export class ModelManagementController extends ControllerBase {
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: RegisteredModelService, filePath?: string): Promise<void> {
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise<void> {
if (filePath) {
await service.registerLocalModel(filePath);
await service.registerLocalModel(filePath, details);
} else {
throw Error(constants.invalidModelToRegisterError);
@@ -173,13 +174,15 @@ export class ModelManagementController extends ControllerBase {
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined): Promise<void> {
if (!account || !subscription || !resourceGroup || !workspace || !model) {
model: WorkspaceModel | undefined,
details: RegisteredModel | undefined): Promise<void> {
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
throw Error(constants.invalidAzureResourceError);
}
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
await service.registerLocalModel(filePath);
await service.registerLocalModel(filePath, details);
await fs.promises.unlink(filePath);
} else {
throw Error(constants.invalidModelToRegisterError);

View File

@@ -0,0 +1,92 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
/**
* View to pick model source
*/
export class ModelSourcePage extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelResources: ModelSourcesComponent | undefined;
public localModelsComponent: LocalModelsComponent | undefined;
public azureModelsComponent: AzureModelsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
this.modelResources.registerComponent(modelBuilder);
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
this.localModelsComponent.registerComponent(modelBuilder);
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
this.azureModelsComponent.registerComponent(modelBuilder);
this.modelResources.addComponents(this._formBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelSourceType {
return this.modelResources?.data || ModelSourceType.Local;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.modelResources && this.modelResources.data === ModelSourceType.Local) {
if (this.localModelsComponent && this.azureModelsComponent) {
this.azureModelsComponent.removeComponents(this._formBuilder);
this.localModelsComponent.addComponents(this._formBuilder);
await this.localModelsComponent.refresh();
}
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) {
if (this.localModelsComponent && this.azureModelsComponent) {
this.localModelsComponent.removeComponents(this._formBuilder);
this.azureModelsComponent.addComponents(this._formBuilder);
await this.azureModelsComponent.refresh();
}
}
}
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelSourcePageTitle;
}
}

View File

@@ -7,18 +7,19 @@ import * as azdata from 'azdata';
import { ModelViewBase, SourceModelSelectedEventName } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { IDataComponent } from '../interfaces';
export enum ModelSourceType {
Local,
Azure
}
/**
* View tp pick model source
* View to pick model source
*/
export class ModelSourcesComponent extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
export class ModelSourcesComponent extends ModelViewBase implements IDataComponent<ModelSourceType> {
private _form: azdata.FormContainer | undefined;
private _flexContainer: azdata.FlexContainer | undefined;
private _amlModel: azdata.RadioButtonComponent | undefined;
private _localModel: azdata.RadioButtonComponent | undefined;
private _isLocalModel: boolean = true;
@@ -58,7 +59,8 @@ export class ModelSourcesComponent extends ModelViewBase implements IPageView, I
this.sendRequest(SourceModelSelectedEventName);
});
let flex = modelBuilder.flexContainer()
this._flexContainer = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'column',
justifyContent: 'space-between'
@@ -67,12 +69,25 @@ export class ModelSourcesComponent extends ModelViewBase implements IPageView, I
).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.modelSourcesTitle,
component: flex
title: '',
component: this._flexContainer
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this._flexContainer });
}
}
/**
* Returns selected data
*/
@@ -92,11 +107,4 @@ export class ModelSourcesComponent extends ModelViewBase implements IPageView, I
*/
public async refresh(): Promise<void> {
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelSourcesTitle;
}
}

View File

@@ -15,11 +15,15 @@ import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
}
export interface RegisterAzureModelEventArgs extends AzureModelResource {
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
details?: RegisteredModel
}
export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs {
model?: WorkspaceModel;
}
export interface RegisterLocalModelEventArgs extends AzureResourceEventArgs {
export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
filePath?: string;
}
@@ -102,9 +106,10 @@ export abstract class ModelViewBase extends ViewBase {
* registers local model
* @param localFilePath local file path
*/
public async registerLocalModel(localFilePath: string | undefined): Promise<void> {
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModel | undefined): Promise<void> {
const args: RegisterLocalModelEventArgs = {
filePath: localFilePath
filePath: localFilePath,
details: details
};
return await this.sendDataRequest(RegisterLocalModelEventName, args);
}
@@ -113,7 +118,10 @@ export abstract class ModelViewBase extends ViewBase {
* registers azure model
* @param args azure resource
*/
public async registerAzureModel(args: RegisterAzureModelEventArgs | undefined): Promise<void> {
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModel | undefined): Promise<void> {
const args: RegisterAzureModelEventArgs = Object.assign({}, resource, {
details: details
});
return await this.sendDataRequest(RegisterAzureModelEventName, args);
}

View File

@@ -11,15 +11,16 @@ import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import * as constants from '../../common/constants';
import { WizardView } from '../wizardView';
import { ModelSourcePage } from './modelSourcePage';
import { ModelDetailsPage } from './modelDetailsPage';
/**
* Wizard to register a model
*/
export class RegisterModelWizard extends ModelViewBase {
public modelResources: ModelSourcesComponent | undefined;
public localModelsComponent: LocalModelsComponent | undefined;
public azureModelsComponent: AzureModelsComponent | undefined;
public modelSourcePage: ModelSourcePage | undefined;
public modelDetailsPage: ModelDetailsPage | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
@@ -35,21 +36,23 @@ export class RegisterModelWizard extends ModelViewBase {
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.registerModelWizardTitle, [this.modelResources, this.localModelsComponent]);
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelDetailsPage]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.azureRegisterModel;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
if (pageInfo.newPage === undefined) {
wizard.cancelButton.enabled = false;
wizard.backButton.enabled = false;
await this.registerModel();
wizard.cancelButton.enabled = true;
wizard.backButton.enabled = true;
if (this._parentView) {
this._parentView?.refresh();
}
@@ -62,12 +65,24 @@ export class RegisterModelWizard extends ModelViewBase {
wizard.open();
}
public get modelResources(): ModelSourcesComponent | undefined {
return this.modelSourcePage?.modelResources;
}
public get localModelsComponent(): LocalModelsComponent | undefined {
return this.modelSourcePage?.localModelsComponent;
}
public get azureModelsComponent(): AzureModelsComponent | undefined {
return this.modelSourcePage?.azureModelsComponent;
}
private async registerModel(): Promise<boolean> {
try {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
await this.registerLocalModel(this.localModelsComponent.data);
await this.registerLocalModel(this.localModelsComponent.data, this.modelDetailsPage?.data);
} else {
await this.registerAzureModel(this.azureModelsComponent?.data);
await this.registerAzureModel(this.azureModelsComponent?.data, this.modelDetailsPage?.data);
}
this.showInfoMessage(constants.modelRegisteredSuccessfully);
return true;
@@ -78,12 +93,6 @@ export class RegisterModelWizard extends ModelViewBase {
}
private loadPages(): void {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
this.wizardView?.addWizardPage(this.localModelsComponent, 1);
} else if (this.azureModelsComponent) {
this.wizardView?.addWizardPage(this.azureModelsComponent, 1);
}
}
/**
@@ -91,6 +100,6 @@ export class RegisterModelWizard extends ModelViewBase {
*/
public async refresh(): Promise<void> {
this.loadPages();
this.wizardView?.refresh();
await this.wizardView?.refresh();
}
}