mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 18:46:40 -05:00
@@ -42,12 +42,12 @@
|
|||||||
},
|
},
|
||||||
"machineLearningServices.pythonPath": {
|
"machineLearningServices.pythonPath": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"default": "python",
|
"default": "",
|
||||||
"description": "%mls.pythonPath.description%"
|
"description": "%mls.pythonPath.description%"
|
||||||
},
|
},
|
||||||
"machineLearningServices.rPath": {
|
"machineLearningServices.rPath": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"default": "r",
|
"default": "",
|
||||||
"description": "%mls.rPath.description%"
|
"description": "%mls.rPath.description%"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ export const pythonEnabledConfigKey = 'enablePython';
|
|||||||
export const rEnabledConfigKey = 'enableR';
|
export const rEnabledConfigKey = 'enableR';
|
||||||
export const registeredModelsTableName = 'registeredModelsTableName';
|
export const registeredModelsTableName = 'registeredModelsTableName';
|
||||||
export const rPathConfigKey = 'rPath';
|
export const rPathConfigKey = 'rPath';
|
||||||
|
export const adsPythonBundleVersion = '0.0.1';
|
||||||
|
|
||||||
// Localized texts
|
// Localized texts
|
||||||
//
|
//
|
||||||
@@ -48,8 +49,10 @@ export const msgYes = localize('msgYes', "Yes");
|
|||||||
export const msgNo = localize('msgNo', "No");
|
export const msgNo = localize('msgNo', "No");
|
||||||
export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed.");
|
export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed.");
|
||||||
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
|
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
|
||||||
export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Installing package management dependencies");
|
export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); }
|
||||||
export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Installing model management dependencies");
|
export function cannotFindR(path: string): string { return localize('mls.cannotFindR.error', "Cannot find R executable '{0}'. Please make sure R is installed and configured correctly", path); }
|
||||||
|
export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Verifying package management dependencies");
|
||||||
|
export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Verifying model management dependencies");
|
||||||
export const noResultError = localize('mls.noResultError', "No Result returned");
|
export const noResultError = localize('mls.noResultError', "No Result returned");
|
||||||
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
|
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
|
||||||
export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that.");
|
export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that.");
|
||||||
@@ -122,6 +125,8 @@ export const extLangInstallFailedError = localize('extLang.installFailedError',
|
|||||||
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
|
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
|
||||||
|
|
||||||
export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model");
|
export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model");
|
||||||
|
export const modelsListEmptyMessage = localize('models.modelsListEmptyMessage', "No Models Yet");
|
||||||
|
export const modelsListEmptyDescription = localize('models.modelsListEmptyDescription', "Use import wizard to add models to this table");
|
||||||
export const databaseName = localize('databaseName', "Models database");
|
export const databaseName = localize('databaseName', "Models database");
|
||||||
export const tableName = localize('tableName', "Models table");
|
export const tableName = localize('tableName', "Models table");
|
||||||
export const existingTableName = localize('existingTableName', "Existing table");
|
export const existingTableName = localize('existingTableName', "Existing table");
|
||||||
@@ -195,6 +200,7 @@ export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMism
|
|||||||
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||||
export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!");
|
export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!");
|
||||||
|
export const selectModelsTableMessage = localize('models.selectModelsTableMessage', "Select models table");
|
||||||
export const modelSchemaIsNotAcceptedMessage = localize('models.modelSchemaIsNotAcceptedMessage', "Invalid table structure");
|
export const modelSchemaIsNotAcceptedMessage = localize('models.modelSchemaIsNotAcceptedMessage', "Invalid table structure");
|
||||||
export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); }
|
export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); }
|
||||||
export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
|
export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
|
||||||
|
|||||||
@@ -44,6 +44,15 @@ export async function exists(path: string): Promise<boolean> {
|
|||||||
return promisify(fs.exists)(path);
|
return promisify(fs.exists)(path);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function isDirectory(path: string): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
const stat = await fs.promises.lstat(path);
|
||||||
|
return stat.isDirectory();
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export async function createFolder(dirPath: string): Promise<void> {
|
export async function createFolder(dirPath: string): Promise<void> {
|
||||||
let folderExists = await exists(dirPath);
|
let folderExists = await exists(dirPath);
|
||||||
if (!folderExists) {
|
if (!folderExists) {
|
||||||
@@ -259,3 +268,18 @@ export function getFileName(filePath: string) {
|
|||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getDefaultPythonLocation(): string {
|
||||||
|
|
||||||
|
return path.join(getUserHome() || '', 'azuredatastudio-python',
|
||||||
|
constants.adsPythonBundleVersion,
|
||||||
|
getPythonExeName());
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getPythonExeName(): string {
|
||||||
|
return process.platform === constants.winPlatform ? 'python.exe' : 'bin/python3';
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getUserHome(): string | undefined {
|
||||||
|
return process.env.HOME || process.env.USERPROFILE;
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,10 +9,11 @@ import * as constants from '../common/constants';
|
|||||||
import { promises as fs } from 'fs';
|
import { promises as fs } from 'fs';
|
||||||
import * as path from 'path';
|
import * as path from 'path';
|
||||||
import { PackageConfigModel } from './packageConfigModel';
|
import { PackageConfigModel } from './packageConfigModel';
|
||||||
|
import * as utils from '../common/utils';
|
||||||
|
|
||||||
const configFileName = 'config.json';
|
const configFileName = 'config.json';
|
||||||
const defaultPythonExecutable = 'python';
|
const defaultPythonExecutable = '';
|
||||||
const defaultRExecutable = 'r';
|
const defaultRExecutable = '';
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -57,8 +58,22 @@ export class Config {
|
|||||||
/**
|
/**
|
||||||
* Returns python path from user settings
|
* Returns python path from user settings
|
||||||
*/
|
*/
|
||||||
public get pythonExecutable(): string {
|
public async getPythonExecutable(verify: boolean): Promise<string> {
|
||||||
return this.config.get(constants.pythonPathConfigKey) || defaultPythonExecutable;
|
let executable: string = this.config.get(constants.pythonPathConfigKey) || defaultPythonExecutable;
|
||||||
|
if (!executable) {
|
||||||
|
executable = utils.getDefaultPythonLocation();
|
||||||
|
} else {
|
||||||
|
const exeName = utils.getPythonExeName();
|
||||||
|
const isFolder = await utils.isDirectory(executable);
|
||||||
|
if (isFolder && executable.indexOf(exeName) < 0) {
|
||||||
|
executable = path.join(executable, exeName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let checkExist = executable && executable.toLocaleUpperCase() !== 'PYTHON' && executable.toLocaleUpperCase() !== 'PYTHON3';
|
||||||
|
if (verify && checkExist && !await utils.exists(executable)) {
|
||||||
|
throw new Error(constants.cannotFindPython(executable));
|
||||||
|
}
|
||||||
|
return executable;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -128,8 +143,14 @@ export class Config {
|
|||||||
/**
|
/**
|
||||||
* Returns r path from user settings
|
* Returns r path from user settings
|
||||||
*/
|
*/
|
||||||
public get rExecutable(): string {
|
public async getRExecutable(verify: boolean): Promise<string> {
|
||||||
return this.config.get(constants.rPathConfigKey) || defaultRExecutable;
|
let executable: string = this.config.get(constants.rPathConfigKey) || defaultRExecutable;
|
||||||
|
let checkExist = executable && executable.toLocaleUpperCase() !== 'R';
|
||||||
|
if (verify && checkExist && !await utils.exists(executable)) {
|
||||||
|
throw new Error(constants.cannotFindR(executable));
|
||||||
|
}
|
||||||
|
|
||||||
|
return executable;
|
||||||
}
|
}
|
||||||
|
|
||||||
private get config(): vscode.WorkspaceConfiguration {
|
private get config(): vscode.WorkspaceConfiguration {
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ export class ModelPythonClient {
|
|||||||
'addParameters(onnx_model.graph.output, "outputs")',
|
'addParameters(onnx_model.graph.output, "outputs")',
|
||||||
'print(json.dumps(parameters))'
|
'print(json.dumps(parameters))'
|
||||||
];
|
];
|
||||||
let pythonExecutable = this._config.pythonExecutable;
|
let pythonExecutable = await this._config.getPythonExecutable(true);
|
||||||
let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
|
let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
|
||||||
let parametersJson = JSON.parse(output);
|
let parametersJson = JSON.parse(output);
|
||||||
return Object.assign({}, parametersJson);
|
return Object.assign({}, parametersJson);
|
||||||
@@ -124,7 +124,7 @@ export class ModelPythonClient {
|
|||||||
'mlflow.set_experiment(exp_name)',
|
'mlflow.set_experiment(exp_name)',
|
||||||
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'
|
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'
|
||||||
];
|
];
|
||||||
let pythonExecutable = this._config.pythonExecutable;
|
let pythonExecutable = await this._config.getPythonExecutable(true);
|
||||||
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
|
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,12 +45,12 @@ export class PackageManager {
|
|||||||
public init(): void {
|
public init(): void {
|
||||||
}
|
}
|
||||||
|
|
||||||
private get pythonExecutable(): string {
|
private async getPythonExecutable(): Promise<string> {
|
||||||
return this._config.pythonExecutable;
|
return await this._config.getPythonExecutable(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
private get _rExecutable(): string {
|
private async getRExecutable(): Promise<string> {
|
||||||
return this._config.rExecutable;
|
return await this._config.getRExecutable(true);
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Returns packageManageProviders
|
* Returns packageManageProviders
|
||||||
@@ -123,7 +123,8 @@ export class PackageManager {
|
|||||||
if (!this._config.rEnabled) {
|
if (!this._config.rEnabled) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!this._rExecutable) {
|
let rExecutable = await this.getRExecutable();
|
||||||
|
if (!rExecutable) {
|
||||||
throw new Error(constants.rConfigError);
|
throw new Error(constants.rConfigError);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,7 +140,8 @@ export class PackageManager {
|
|||||||
if (!this._config.pythonEnabled) {
|
if (!this._config.pythonEnabled) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!this.pythonExecutable) {
|
let pythonExecutable = await this.getPythonExecutable();
|
||||||
|
if (!pythonExecutable) {
|
||||||
throw new Error(constants.pythonConfigError);
|
throw new Error(constants.pythonConfigError);
|
||||||
}
|
}
|
||||||
if (!requiredPackages || requiredPackages.length === 0) {
|
if (!requiredPackages || requiredPackages.length === 0) {
|
||||||
@@ -177,7 +179,8 @@ export class PackageManager {
|
|||||||
|
|
||||||
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
|
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||||
try {
|
try {
|
||||||
let cmd = `"${this.pythonExecutable}" -m pip list --format=json`;
|
let pythonExecutable = await this.getPythonExecutable();
|
||||||
|
let cmd = `"${pythonExecutable}" -m pip list --format=json`;
|
||||||
let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined);
|
let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined);
|
||||||
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
|
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
|
||||||
if (packagesInfo && packagesInfo.indexOf(']') > 0) {
|
if (packagesInfo && packagesInfo.indexOf(']') > 0) {
|
||||||
@@ -196,23 +199,25 @@ export class PackageManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async installPipPackage(requirementFilePath: string): Promise<string> {
|
private async installPipPackage(requirementFilePath: string): Promise<string> {
|
||||||
let cmd = `"${this.pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
|
let pythonExecutable = await this.getPythonExecutable();
|
||||||
|
let cmd = `"${pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
|
||||||
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async installRPackage(model: PackageConfigModel): Promise<string> {
|
private async installRPackage(model: PackageConfigModel): Promise<string> {
|
||||||
let output = '';
|
let output = '';
|
||||||
let cmd = '';
|
let cmd = '';
|
||||||
|
let rExecutable = await this.getRExecutable();
|
||||||
if (model.downloadUrl) {
|
if (model.downloadUrl) {
|
||||||
const packageFile = utils.getPackageFilePath(this._rootFolder, model.fileName || model.name);
|
const packageFile = utils.getPackageFilePath(this._rootFolder, model.fileName || model.name);
|
||||||
const packageExist = await utils.exists(packageFile);
|
const packageExist = await utils.exists(packageFile);
|
||||||
if (!packageExist) {
|
if (!packageExist) {
|
||||||
await this._httpClient.download(model.downloadUrl, packageFile, this._outputChannel);
|
await this._httpClient.download(model.downloadUrl, packageFile, this._outputChannel);
|
||||||
}
|
}
|
||||||
cmd = `"${this._rExecutable}" CMD INSTALL ${packageFile}`;
|
cmd = `"${rExecutable}" CMD INSTALL ${packageFile}`;
|
||||||
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||||
} else if (model.repository) {
|
} else if (model.repository) {
|
||||||
cmd = `"${this._rExecutable}" -e "install.packages('${model.name}', repos='${model.repository}')"`;
|
cmd = `"${rExecutable}" -e "install.packages('${model.name}', repos='${model.repository}')"`;
|
||||||
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||||
}
|
}
|
||||||
return output;
|
return output;
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase
|
|||||||
'pkgmanager = sqlmlutils.SQLPackageManager(connection)',
|
'pkgmanager = sqlmlutils.SQLPackageManager(connection)',
|
||||||
pythonCommandScript
|
pythonCommandScript
|
||||||
];
|
];
|
||||||
let pythonExecutable = this._config.pythonExecutable;
|
let pythonExecutable = await this._config.getPythonExecutable(true);
|
||||||
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
|
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ export class SqlRPackageManageProvider extends SqlPackageManageProviderBase impl
|
|||||||
`${rCommandScript}(connectionString = connection, pkgs, scope = "PUBLIC")`,
|
`${rCommandScript}(connectionString = connection, pkgs, scope = "PUBLIC")`,
|
||||||
'q()'
|
'q()'
|
||||||
];
|
];
|
||||||
let rExecutable = this._config.rExecutable;
|
let rExecutable = await this._config.getRExecutable(true);
|
||||||
await this._processService.execScripts(`${rExecutable}`, scripts, ['--vanilla'], this._outputChannel);
|
await this._processService.execScripts(`${rExecutable}`, scripts, ['--vanilla'], this._outputChannel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -176,7 +176,8 @@ AS (
|
|||||||
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
|
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getPredictInputColumnNames(outputColumns, 'p')}
|
${this.getPredictColumnNames(columns, 'predict_input')},
|
||||||
|
${this.getPredictInputColumnNames(outputColumns, 'p')}
|
||||||
FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx)
|
FROM PREDICT(MODEL = @model, DATA = predict_input, runtime=onnx)
|
||||||
WITH (
|
WITH (
|
||||||
${this.getOutputParameters(outputColumns)}
|
${this.getOutputParameters(outputColumns)}
|
||||||
@@ -197,7 +198,8 @@ AS (
|
|||||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')}
|
${this.getPredictColumnNames(columns, 'predict_input')},
|
||||||
|
${this.getPredictInputColumnNames(outputColumns, 'p')}
|
||||||
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx)
|
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input, runtime=onnx)
|
||||||
WITH (
|
WITH (
|
||||||
${this.getOutputParameters(outputColumns)}
|
${this.getOutputParameters(outputColumns)}
|
||||||
@@ -224,12 +226,6 @@ WITH (
|
|||||||
}).join(',\n');
|
}).join(',\n');
|
||||||
}
|
}
|
||||||
|
|
||||||
private getOutputColumnNames(columns: PredictColumn[], tableName: string) {
|
|
||||||
return columns.map(c => {
|
|
||||||
return this.getColumnName(tableName, c.columnName, c.paramName || '');
|
|
||||||
}).join(',\n');
|
|
||||||
}
|
|
||||||
|
|
||||||
private getColumnName(tableName: string, columnName: string, displayName: string) {
|
private getColumnName(tableName: string, columnName: string, displayName: string) {
|
||||||
const column = this.getEscapedColumnName(tableName, columnName);
|
const column = this.getEscapedColumnName(tableName, columnName);
|
||||||
return columnName && columnName !== displayName ?
|
return columnName && columnName !== displayName ?
|
||||||
|
|||||||
110
extensions/machine-learning/src/test/common/config.test.ts
Normal file
110
extensions/machine-learning/src/test/common/config.test.ts
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as vscode from 'vscode';
|
||||||
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
|
import * as TypeMoq from 'typemoq';
|
||||||
|
import * as should from 'should';
|
||||||
|
import { Config } from '../../configurations/config';
|
||||||
|
import * as utils from '../../common/utils';
|
||||||
|
import * as path from 'path';
|
||||||
|
|
||||||
|
interface TestContext {
|
||||||
|
|
||||||
|
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||||
|
}
|
||||||
|
|
||||||
|
function createContext(): TestContext {
|
||||||
|
return {
|
||||||
|
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let configData : vscode.WorkspaceConfiguration = {
|
||||||
|
get: () => {},
|
||||||
|
has: () => true,
|
||||||
|
inspect: () => undefined,
|
||||||
|
update: () => {return Promise.resolve();},
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
describe('Config', () => {
|
||||||
|
it('getPythonExecutable should default to ADS python location is not configured', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return ''; };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
const expected = utils.getDefaultPythonLocation();
|
||||||
|
const actual = await config.getPythonExecutable(false);
|
||||||
|
should.deepEqual(actual, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getPythonExecutable should add python executable name is folder path is configured', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return utils.getUserHome(); };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
const expected = path.join(utils.getUserHome() || '', utils.getPythonExeName());
|
||||||
|
const actual = await config.getPythonExecutable(false);
|
||||||
|
should.deepEqual(actual, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getPythonExecutable should not add python executable if already added', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return path.join(utils.getUserHome() || '', utils.getPythonExeName()); };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
const expected = path.join(utils.getUserHome() || '', utils.getPythonExeName());
|
||||||
|
const actual = await config.getPythonExecutable(false);
|
||||||
|
should.deepEqual(actual, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getPythonExecutable should not add python executable set to python', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return 'python'; };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
const expected = 'python';
|
||||||
|
const actual = await config.getPythonExecutable(false);
|
||||||
|
should.deepEqual(actual, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getPythonExecutable should not add python executable set to python3', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return 'python3'; };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
const expected = 'python3';
|
||||||
|
const actual = await config.getPythonExecutable(false);
|
||||||
|
should.deepEqual(actual, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getRExecutable should not add r executable set to r', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return 'r'; };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
const expected = 'r';
|
||||||
|
const actual = await config.getRExecutable(false);
|
||||||
|
should.deepEqual(actual, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getPythonExecutable should throw error if file does not exist', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return path.join(utils.getUserHome() || '', 'invalidPath'); };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
await should(config.getPythonExecutable(true)).be.rejected();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('getRExecutable should throw error if file does not exist', async function (): Promise<void> {
|
||||||
|
const context = createContext();
|
||||||
|
configData.get = () => { return path.join(utils.getUserHome() || '', 'invalidPath'); };
|
||||||
|
context.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => configData);
|
||||||
|
let config = new Config('', context.apiWrapper.object);
|
||||||
|
await should(config.getRExecutable(true)).be.rejected();
|
||||||
|
});
|
||||||
|
|
||||||
|
});
|
||||||
@@ -53,7 +53,7 @@ describe('ModelPythonClient', () => {
|
|||||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||||
operationInfo.operation(testContext.op);
|
operationInfo.operation(testContext.op);
|
||||||
});
|
});
|
||||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('pythonPath'));
|
||||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ describe('ModelPythonClient', () => {
|
|||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.packageManager.object);
|
testContext.packageManager.object);
|
||||||
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('pythonPath'));
|
||||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
|
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
|
||||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||||
|
|||||||
@@ -254,8 +254,8 @@ describe('Package Manager', () => {
|
|||||||
{ name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'}
|
{ name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'}
|
||||||
]);
|
]);
|
||||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
|
testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('python'));
|
||||||
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
|
testContext.config.setup(x => x.getRExecutable(true)).returns(() => Promise.resolve('r'));
|
||||||
testContext.config.setup(x => x.rEnabled).returns(() => true);
|
testContext.config.setup(x => x.rEnabled).returns(() => true);
|
||||||
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
|
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
|
||||||
let packageManager = new PackageManager(
|
let packageManager = new PackageManager(
|
||||||
|
|||||||
@@ -386,7 +386,7 @@ describe('SQL Python Package Manager', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
function createProvider(testContext: TestContext): SqlPythonPackageManageProvider {
|
function createProvider(testContext: TestContext): SqlPythonPackageManageProvider {
|
||||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
|
testContext.config.setup(x => x.getPythonExecutable(true)).returns(() => Promise.resolve('python'));
|
||||||
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
|
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
|
||||||
return new SqlPythonPackageManageProvider(
|
return new SqlPythonPackageManageProvider(
|
||||||
testContext.outputChannel,
|
testContext.outputChannel,
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ describe('SQL R Package Manager', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
function createProvider(testContext: TestContext): SqlRPackageManageProvider {
|
function createProvider(testContext: TestContext): SqlRPackageManageProvider {
|
||||||
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
|
testContext.config.setup(x => x.getRExecutable(true)).returns(() => Promise.resolve('r'));
|
||||||
testContext.config.setup(x => x.rEnabled).returns(() => true);
|
testContext.config.setup(x => x.rEnabled).returns(() => true);
|
||||||
testContext.config.setup(x => x.rPackagesRepository).returns(() => 'http://cran.r-project.org');
|
testContext.config.setup(x => x.rPackagesRepository).returns(() => 'http://cran.r-project.org');
|
||||||
return new SqlRPackageManageProvider(
|
return new SqlRPackageManageProvider(
|
||||||
|
|||||||
@@ -39,10 +39,12 @@ describe('Dashboard widget', () => {
|
|||||||
await handler(testContext.view);
|
await handler(testContext.view);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
testContext.apiWrapper.setup(x => x.openExternal(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||||
|
|
||||||
testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(true));
|
testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(true));
|
||||||
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '', testContext.predictService.object);
|
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '', testContext.predictService.object);
|
||||||
await dashboard.register();
|
await dashboard.register();
|
||||||
testContext.onClick.fire(undefined);
|
testContext.onClick.fire(undefined);
|
||||||
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce());
|
testContext.apiWrapper.verify(x => x.openExternal(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -17,10 +17,13 @@ import { ImportedModel } from '../../../modelManagement/interfaces';
|
|||||||
* View to render current registered models
|
* View to render current registered models
|
||||||
*/
|
*/
|
||||||
export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
||||||
private _tableComponent: azdata.Component | undefined;
|
|
||||||
private _dataTable: CurrentModelsTable | undefined;
|
private _dataTable: CurrentModelsTable | undefined;
|
||||||
private _loader: azdata.LoadingComponent | undefined;
|
private _loader: azdata.LoadingComponent | undefined;
|
||||||
private _tableSelectionComponent: TableSelectionComponent | undefined;
|
private _tableSelectionComponent: TableSelectionComponent | undefined;
|
||||||
|
private _labelComponent: azdata.TextComponent | undefined;
|
||||||
|
private _descriptionComponent: azdata.TextComponent | undefined;
|
||||||
|
private _labelContainer: azdata.FlexContainer | undefined;
|
||||||
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
@@ -43,37 +46,69 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
|||||||
});
|
});
|
||||||
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings);
|
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings);
|
||||||
this._dataTable.registerComponent(modelBuilder);
|
this._dataTable.registerComponent(modelBuilder);
|
||||||
this._tableComponent = this._dataTable.component;
|
|
||||||
|
|
||||||
let formModelBuilder = modelBuilder.formContainer();
|
let formModelBuilder = modelBuilder.formContainer();
|
||||||
this._tableSelectionComponent.addComponents(formModelBuilder);
|
|
||||||
|
|
||||||
if (this._tableComponent) {
|
|
||||||
formModelBuilder.addFormItem({
|
|
||||||
component: this._tableComponent,
|
|
||||||
title: ''
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
this._loader = modelBuilder.loadingComponent()
|
this._loader = modelBuilder.loadingComponent()
|
||||||
.withItem(formModelBuilder.component())
|
.withItem(formModelBuilder.component())
|
||||||
.withProperties({
|
.withProperties({
|
||||||
loading: true
|
loading: true
|
||||||
}).component();
|
}).component();
|
||||||
|
this._labelComponent = modelBuilder.text().withProperties({
|
||||||
|
width: 200,
|
||||||
|
value: constants.modelsListEmptyMessage
|
||||||
|
}).component();
|
||||||
|
this._descriptionComponent = modelBuilder.text().withProperties({
|
||||||
|
width: 200,
|
||||||
|
value: constants.modelsListEmptyDescription
|
||||||
|
}).component();
|
||||||
|
this._labelContainer = modelBuilder.flexContainer().withLayout({
|
||||||
|
flexFlow: 'column',
|
||||||
|
width: 800,
|
||||||
|
height: '400px',
|
||||||
|
justifyContent: 'center'
|
||||||
|
}).component();
|
||||||
|
|
||||||
|
this._labelContainer.addItem(
|
||||||
|
this._labelComponent
|
||||||
|
, {
|
||||||
|
CSSStyles: {
|
||||||
|
'align-items': 'center',
|
||||||
|
'padding-top': '30px',
|
||||||
|
'padding-left': `${this.componentMaxLength}px`,
|
||||||
|
'font-size': '16px'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
this._labelContainer.addItem(
|
||||||
|
this._descriptionComponent
|
||||||
|
, {
|
||||||
|
CSSStyles: {
|
||||||
|
'align-items': 'center',
|
||||||
|
'padding-top': '10px',
|
||||||
|
'padding-left': `${this.componentMaxLength - 50}px`,
|
||||||
|
'font-size': '13px'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
this.addComponents(formModelBuilder);
|
||||||
return this._loader;
|
return this._loader;
|
||||||
}
|
}
|
||||||
|
|
||||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._tableSelectionComponent && this._dataTable) {
|
this._formBuilder = formBuilder;
|
||||||
|
if (this._tableSelectionComponent && this._dataTable && this._labelContainer) {
|
||||||
this._tableSelectionComponent.addComponents(formBuilder);
|
this._tableSelectionComponent.addComponents(formBuilder);
|
||||||
this._dataTable.addComponents(formBuilder);
|
this._dataTable.addComponents(formBuilder);
|
||||||
|
if (this._dataTable.isEmpty) {
|
||||||
|
formBuilder.addFormItem({ title: '', component: this._labelContainer });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._tableSelectionComponent && this._dataTable) {
|
if (this._tableSelectionComponent && this._dataTable && this._labelContainer) {
|
||||||
this._tableSelectionComponent.removeComponents(formBuilder);
|
this._tableSelectionComponent.removeComponents(formBuilder);
|
||||||
this._dataTable.removeComponents(formBuilder);
|
this._dataTable.removeComponents(formBuilder);
|
||||||
|
formBuilder.removeFormItem({ title: '', component: this._labelContainer });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,10 +126,11 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
|||||||
await this.onLoading();
|
await this.onLoading();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (this._tableSelectionComponent) {
|
if (this._tableSelectionComponent && this._dataTable) {
|
||||||
this._tableSelectionComponent.refresh();
|
await this._tableSelectionComponent.refresh();
|
||||||
|
await this._dataTable.refresh();
|
||||||
|
this.refreshComponents();
|
||||||
}
|
}
|
||||||
await this._dataTable?.refresh();
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.showErrorMessage(constants.getErrorMessage(err));
|
this.showErrorMessage(constants.getErrorMessage(err));
|
||||||
} finally {
|
} finally {
|
||||||
@@ -106,6 +142,13 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
|||||||
return this._dataTable?.data;
|
return this._dataTable?.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private refreshComponents(): void {
|
||||||
|
if (this._formBuilder) {
|
||||||
|
this.removeComponents(this._formBuilder);
|
||||||
|
this.addComponents(this._formBuilder);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private async onTableSelected(): Promise<void> {
|
private async onTableSelected(): Promise<void> {
|
||||||
if (this._tableSelectionComponent?.data) {
|
if (this._tableSelectionComponent?.data) {
|
||||||
this.importTable = this._tableSelectionComponent?.data;
|
this.importTable = this._tableSelectionComponent?.data;
|
||||||
@@ -113,6 +156,7 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
|||||||
if (this._dataTable) {
|
if (this._dataTable) {
|
||||||
await this._dataTable.refresh();
|
await this._dataTable.refresh();
|
||||||
}
|
}
|
||||||
|
this.refreshComponents();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
private _downloadedFile: ModelArtifact | undefined;
|
private _downloadedFile: ModelArtifact | undefined;
|
||||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||||
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
||||||
|
public isEmpty: boolean = false;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates new view
|
* Creates new view
|
||||||
@@ -149,7 +150,6 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the component
|
* Returns the component
|
||||||
*/
|
*/
|
||||||
@@ -176,6 +176,8 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
|
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
this.isEmpty = models === undefined || models.length === 0;
|
||||||
|
|
||||||
this._table.data = tableData;
|
this._table.data = tableData;
|
||||||
}
|
}
|
||||||
this.onModelSelected();
|
this.onModelSelected();
|
||||||
@@ -275,7 +277,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
if (confirm) {
|
if (confirm) {
|
||||||
await this.sendDataRequest(DeleteModelEventName, model);
|
await this.sendDataRequest(DeleteModelEventName, model);
|
||||||
if (this.parent) {
|
if (this.parent) {
|
||||||
await this.parent?.refresh();
|
await this.parent.refresh();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
@@ -108,11 +108,12 @@ export class ImportModelWizard extends ModelViewBase {
|
|||||||
} else {
|
} else {
|
||||||
await this.importAzureModel(this.modelsViewData);
|
await this.importAzureModel(this.modelsViewData);
|
||||||
}
|
}
|
||||||
|
this._apiWrapper.showInfoMessage(constants.modelRegisteredSuccessfully);
|
||||||
await this.storeImportConfigTable();
|
await this.storeImportConfigTable();
|
||||||
this.showInfoMessage(constants.modelRegisteredSuccessfully);
|
|
||||||
return true;
|
return true;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
|
await this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView,
|
|||||||
private _formBuilder: azdata.FormBuilder | undefined;
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
public tableSelectionComponent: TableSelectionComponent | undefined;
|
public tableSelectionComponent: TableSelectionComponent | undefined;
|
||||||
private _labelComponent: azdata.TextComponent | undefined;
|
private _labelComponent: azdata.TextComponent | undefined;
|
||||||
|
private _descriptionComponent: azdata.TextComponent | undefined;
|
||||||
|
private _labelContainer: azdata.FlexContainer | undefined;
|
||||||
|
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
@@ -33,23 +36,40 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView,
|
|||||||
|
|
||||||
this._formBuilder = modelBuilder.formContainer();
|
this._formBuilder = modelBuilder.formContainer();
|
||||||
this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, { editable: true, preSelected: true });
|
this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, { editable: true, preSelected: true });
|
||||||
|
this._descriptionComponent = modelBuilder.text().withProperties({
|
||||||
|
width: 200
|
||||||
|
}).component();
|
||||||
this._labelComponent = modelBuilder.text().withProperties({
|
this._labelComponent = modelBuilder.text().withProperties({
|
||||||
width: 200
|
width: 200
|
||||||
}).component();
|
}).component();
|
||||||
const container = modelBuilder.flexContainer().withLayout({
|
this._labelContainer = modelBuilder.flexContainer().withLayout({
|
||||||
|
flexFlow: 'column',
|
||||||
width: 800,
|
width: 800,
|
||||||
height: '400px',
|
height: '300px',
|
||||||
justifyContent: 'center'
|
justifyContent: 'center'
|
||||||
}).withItems([
|
|
||||||
this._labelComponent
|
|
||||||
], {
|
|
||||||
CSSStyles: {
|
|
||||||
'align-items': 'center',
|
|
||||||
'padding-top': '30px',
|
|
||||||
'font-size': '16px'
|
|
||||||
}
|
|
||||||
}).component();
|
}).component();
|
||||||
|
|
||||||
|
this._labelContainer.addItem(
|
||||||
|
this._labelComponent
|
||||||
|
, {
|
||||||
|
CSSStyles: {
|
||||||
|
'align-items': 'center',
|
||||||
|
'padding-top': '10px',
|
||||||
|
'padding-left': `${this.componentMaxLength}px`,
|
||||||
|
'font-size': '16px'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
this._labelContainer.addItem(
|
||||||
|
this._descriptionComponent
|
||||||
|
, {
|
||||||
|
CSSStyles: {
|
||||||
|
'align-items': 'center',
|
||||||
|
'padding-top': '10px',
|
||||||
|
'padding-left': `${this.componentMaxLength - 80}px`,
|
||||||
|
'font-size': '13px'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
this.tableSelectionComponent.onSelectedChanged(async () => {
|
this.tableSelectionComponent.onSelectedChanged(async () => {
|
||||||
await this.onTableSelected();
|
await this.onTableSelected();
|
||||||
@@ -59,7 +79,7 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView,
|
|||||||
|
|
||||||
this._formBuilder.addFormItem({
|
this._formBuilder.addFormItem({
|
||||||
title: '',
|
title: '',
|
||||||
component: container
|
component: this._labelContainer
|
||||||
});
|
});
|
||||||
this._form = this._formBuilder.component();
|
this._form = this._formBuilder.component();
|
||||||
return this._form;
|
return this._form;
|
||||||
@@ -71,15 +91,24 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (this.importTable && this._labelComponent) {
|
if (this.importTable && this._labelComponent) {
|
||||||
const validated = await this.verifyImportConfigTable(this.importTable);
|
if (!this.validateImportTableName()) {
|
||||||
if (validated) {
|
this._labelComponent.value = constants.selectModelsTableMessage;
|
||||||
this._labelComponent.value = constants.modelSchemaIsAcceptedMessage;
|
|
||||||
} else {
|
} else {
|
||||||
this._labelComponent.value = constants.modelSchemaIsNotAcceptedMessage;
|
const validated = await this.verifyImportConfigTable(this.importTable);
|
||||||
|
if (validated) {
|
||||||
|
this._labelComponent.value = constants.modelSchemaIsAcceptedMessage;
|
||||||
|
} else {
|
||||||
|
this._labelComponent.value = constants.modelSchemaIsNotAcceptedMessage;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private validateImportTableName(): boolean {
|
||||||
|
return this.importTable?.databaseName !== undefined && this.importTable?.databaseName !== constants.selectDatabaseTitle
|
||||||
|
&& this.importTable?.tableName !== undefined && this.importTable?.tableName !== constants.selectTableTitle;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
@@ -116,7 +145,7 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView,
|
|||||||
public async validate(): Promise<boolean> {
|
public async validate(): Promise<boolean> {
|
||||||
let validated = false;
|
let validated = false;
|
||||||
|
|
||||||
if (this.data?.databaseName && this.data?.tableName) {
|
if (this.data && this.validateImportTableName()) {
|
||||||
validated = true;
|
validated = true;
|
||||||
validated = await this.verifyImportConfigTable(this.data);
|
validated = await this.verifyImportConfigTable(this.data);
|
||||||
if (!validated) {
|
if (!validated) {
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
|
|||||||
width: this.componentMaxLength
|
width: this.componentMaxLength
|
||||||
}).component();
|
}).component();
|
||||||
const name = modelParameter.name;
|
const name = modelParameter.name;
|
||||||
let column = values.find(x => x.name === modelParameter.name);
|
let column = values.find(x => x.name.toLocaleUpperCase() === modelParameter.name.toLocaleUpperCase());
|
||||||
if (!column) {
|
if (!column) {
|
||||||
column = values.length > 0 ? values[0] : undefined;
|
column = values.length > 0 ? values[0] : undefined;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,11 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo
|
|||||||
private _dbTableComponent: azdata.FlexContainer | undefined;
|
private _dbTableComponent: azdata.FlexContainer | undefined;
|
||||||
private tableMaxLength = this.componentMaxLength * 2 + 70;
|
private tableMaxLength = this.componentMaxLength * 2 + 70;
|
||||||
private _onSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
private _onSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||||
|
private _existingTableButton: azdata.RadioButtonComponent | undefined;
|
||||||
|
private _newTableButton: azdata.RadioButtonComponent | undefined;
|
||||||
|
private _newTableName: azdata.InputBoxComponent | undefined;
|
||||||
private _existingTablesSelected: boolean = true;
|
private _existingTablesSelected: boolean = true;
|
||||||
|
|
||||||
public readonly onSelectedChanged: vscode.Event<void> = this._onSelectedChanged.event;
|
public readonly onSelectedChanged: vscode.Event<void> = this._onSelectedChanged.event;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -55,50 +59,46 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo
|
|||||||
await this.onDatabaseSelected();
|
await this.onDatabaseSelected();
|
||||||
});
|
});
|
||||||
|
|
||||||
const existingTableButton = modelBuilder.radioButton().withProperties({
|
this._existingTableButton = modelBuilder.radioButton().withProperties({
|
||||||
name: 'tableName',
|
name: 'tableName',
|
||||||
value: 'existing',
|
value: 'existing',
|
||||||
label: 'Existing table',
|
label: 'Existing table',
|
||||||
checked: true
|
checked: true
|
||||||
}).component();
|
}).component();
|
||||||
const newTableButton = modelBuilder.radioButton().withProperties({
|
this._newTableButton = modelBuilder.radioButton().withProperties({
|
||||||
name: 'tableName',
|
name: 'tableName',
|
||||||
value: 'new',
|
value: 'new',
|
||||||
label: 'New table',
|
label: 'New table',
|
||||||
checked: false
|
checked: false
|
||||||
}).component();
|
}).component();
|
||||||
const newTableName = modelBuilder.inputBox().withProperties({
|
this._newTableName = modelBuilder.inputBox().withProperties({
|
||||||
width: this.componentMaxLength - 10,
|
width: this.componentMaxLength - 10,
|
||||||
enabled: false
|
enabled: false
|
||||||
}).component();
|
}).component();
|
||||||
const group = modelBuilder.groupContainer().withItems([
|
const group = modelBuilder.groupContainer().withItems([
|
||||||
existingTableButton,
|
this._existingTableButton,
|
||||||
this._tables,
|
this._tables,
|
||||||
newTableButton,
|
this._newTableButton,
|
||||||
newTableName
|
this._newTableName
|
||||||
], {
|
], {
|
||||||
CSSStyles: {
|
CSSStyles: {
|
||||||
'padding-top': '5px'
|
'padding-top': '5px'
|
||||||
}
|
}
|
||||||
}).component();
|
}).component();
|
||||||
|
|
||||||
existingTableButton.onDidClick(() => {
|
this._existingTableButton.onDidClick(() => {
|
||||||
if (this._tables) {
|
this._existingTablesSelected = true;
|
||||||
this._tables.enabled = existingTableButton.checked;
|
this.refreshTableComponent();
|
||||||
}
|
|
||||||
newTableName.enabled = !existingTableButton.checked;
|
|
||||||
this._existingTablesSelected = existingTableButton.checked || false;
|
|
||||||
});
|
});
|
||||||
newTableButton.onDidClick(() => {
|
this._newTableButton.onDidClick(() => {
|
||||||
if (this._tables) {
|
this._existingTablesSelected = false;
|
||||||
this._tables.enabled = !newTableButton.checked;
|
this.refreshTableComponent();
|
||||||
}
|
|
||||||
newTableName.enabled = newTableButton.checked;
|
|
||||||
this._existingTablesSelected = existingTableButton.checked || false;
|
|
||||||
});
|
});
|
||||||
newTableName.onTextChanged(async () => {
|
this._newTableName.onTextChanged(async () => {
|
||||||
this._selectedTableName = newTableName.value || '';
|
if (this._newTableName) {
|
||||||
await this.onTableSelected();
|
this._selectedTableName = this._newTableName.value || '';
|
||||||
|
await this.onTableSelected();
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
this._tables.onValueChanged(async (value) => {
|
this._tables.onValueChanged(async (value) => {
|
||||||
@@ -192,7 +192,7 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo
|
|||||||
public async loadData(): Promise<void> {
|
public async loadData(): Promise<void> {
|
||||||
this._dbNames = await this.listDatabaseNames();
|
this._dbNames = await this.listDatabaseNames();
|
||||||
let dbNames = this._dbNames;
|
let dbNames = this._dbNames;
|
||||||
if (!this._settings.preSelected && !this._dbNames.find(x => x === constants.selectDatabaseTitle)) {
|
if (!this._dbNames.find(x => x === constants.selectDatabaseTitle)) {
|
||||||
dbNames = [constants.selectDatabaseTitle].concat(this._dbNames);
|
dbNames = [constants.selectDatabaseTitle].concat(this._dbNames);
|
||||||
}
|
}
|
||||||
if (this._databases && dbNames && dbNames.length > 0) {
|
if (this._databases && dbNames && dbNames.length > 0) {
|
||||||
@@ -216,35 +216,49 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async onDatabaseSelected(): Promise<void> {
|
private async onDatabaseSelected(): Promise<void> {
|
||||||
if (this._existingTablesSelected) {
|
this._tableNames = await this.listTableNames(this.databaseName || '');
|
||||||
this._tableNames = await this.listTableNames(this.databaseName || '');
|
let tableNames = this._tableNames;
|
||||||
let tableNames = this._tableNames;
|
if (this._settings.editable && this._tables && this._existingTableButton && this._newTableButton && this._newTableName) {
|
||||||
|
this._existingTablesSelected = this._tableNames !== undefined && this._tableNames.length > 0;
|
||||||
if (this._tableNames && !this._settings.preSelected && !this._tableNames.find(x => x.tableName === constants.selectTableTitle)) {
|
this._newTableButton.checked = !this._existingTablesSelected;
|
||||||
const firstRow: DatabaseTable = { tableName: constants.selectTableTitle, databaseName: '', schema: '' };
|
this._existingTableButton.checked = this._existingTablesSelected;
|
||||||
tableNames = [firstRow].concat(this._tableNames);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this._tables && tableNames && tableNames.length > 0) {
|
|
||||||
this._tables.values = tableNames.map(t => this.getTableFullName(t));
|
|
||||||
if (this.importTable) {
|
|
||||||
const selectedTable = tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema);
|
|
||||||
if (selectedTable) {
|
|
||||||
this._selectedTableName = this.getTableFullName(selectedTable);
|
|
||||||
this._tables.value = this.getTableFullName(selectedTable);
|
|
||||||
} else {
|
|
||||||
this._selectedTableName = this._settings.editable ? this.getTableFullName(this.importTable) : this.getTableFullName(tableNames[0]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
this._selectedTableName = this.getTableFullName(tableNames[0]);
|
|
||||||
}
|
|
||||||
this._tables.value = this._selectedTableName;
|
|
||||||
} else if (this._tables) {
|
|
||||||
this._tables.values = [];
|
|
||||||
this._tables.value = '';
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
this.refreshTableComponent();
|
||||||
|
|
||||||
|
|
||||||
|
if (this._tableNames && !this._tableNames.find(x => x.tableName === constants.selectTableTitle)) {
|
||||||
|
const firstRow: DatabaseTable = { tableName: constants.selectTableTitle, databaseName: '', schema: '' };
|
||||||
|
tableNames = [firstRow].concat(this._tableNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this._tables && tableNames && tableNames.length > 0) {
|
||||||
|
this._tables.values = tableNames.map(t => this.getTableFullName(t));
|
||||||
|
if (this.importTable && this.importTable.databaseName === this._databases?.value) {
|
||||||
|
const selectedTable = tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema);
|
||||||
|
if (selectedTable) {
|
||||||
|
this._selectedTableName = this.getTableFullName(selectedTable);
|
||||||
|
this._tables.value = this.getTableFullName(selectedTable);
|
||||||
|
} else {
|
||||||
|
this._selectedTableName = this._settings.editable ? this.getTableFullName(this.importTable) : this.getTableFullName(tableNames[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this._selectedTableName = this.getTableFullName(tableNames[0]);
|
||||||
|
}
|
||||||
|
this._tables.value = this._selectedTableName;
|
||||||
|
} else if (this._tables) {
|
||||||
|
this._tables.values = [];
|
||||||
|
this._tables.value = '';
|
||||||
|
}
|
||||||
|
|
||||||
await this.onTableSelected();
|
await this.onTableSelected();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private refreshTableComponent(): void {
|
||||||
|
if (this._settings.editable && this._tables && this._existingTableButton && this._newTableButton && this._newTableName) {
|
||||||
|
this._tables.enabled = this._existingTablesSelected;
|
||||||
|
this._newTableName.enabled = !this._existingTablesSelected;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private getTableFullName(table: DatabaseTable): string {
|
private getTableFullName(table: DatabaseTable): string {
|
||||||
|
|||||||
@@ -492,7 +492,9 @@ export class DashboardWidget {
|
|||||||
'padding': '10px'
|
'padding': '10px'
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
predictionButton.enabled = await this._predictService.serverSupportOnnxModel();
|
if (!await this._predictService.serverSupportOnnxModel()) {
|
||||||
|
console.log(constants.onnxNotSupportedError);
|
||||||
|
}
|
||||||
|
|
||||||
return tasksContainer;
|
return tasksContainer;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user