mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
ML - Target import table selectable by user (#10071)
ML - Target import table selectable by user
This commit is contained in:
@@ -115,6 +115,8 @@ export const extLangInstallFailedError = localize('extLang.installFailedError',
|
|||||||
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
|
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
|
||||||
|
|
||||||
export const modelArtifactName = localize('models.artifactName', "Artifact Name");
|
export const modelArtifactName = localize('models.artifactName', "Artifact Name");
|
||||||
|
export const databaseName = localize('databaseName', "Database name");
|
||||||
|
export const tableName = localize('tableName', "Table name");
|
||||||
export const modelName = localize('models.name', "Name");
|
export const modelName = localize('models.name', "Name");
|
||||||
export const modelFileName = localize('models.fileName', "File");
|
export const modelFileName = localize('models.fileName', "File");
|
||||||
export const modelDescription = localize('models.description', "Description");
|
export const modelDescription = localize('models.description', "Description");
|
||||||
@@ -140,13 +142,14 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model
|
|||||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?");
|
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?");
|
||||||
|
export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Where do you want import models to?");
|
||||||
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
|
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
|
||||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details");
|
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details");
|
||||||
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
||||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||||
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
|
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
|
||||||
export const predictModel = localize('models.predictModel', "Predict");
|
export const predictModel = localize('models.predictModel', "Predict");
|
||||||
export const registerModelTitle = localize('models.RegisterWizard', "Deployed models");
|
export const registerModelTitle = localize('models.RegisterWizard', "Import models");
|
||||||
export const importModelTitle = localize('models.importModelTitle', "Import models");
|
export const importModelTitle = localize('models.importModelTitle', "Import models");
|
||||||
export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model");
|
export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model");
|
||||||
export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions");
|
export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions");
|
||||||
@@ -163,9 +166,12 @@ export const invalidAzureResourceError = localize('models.invalidAzureResourceEr
|
|||||||
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
||||||
export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict");
|
export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict");
|
||||||
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
||||||
|
export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table");
|
||||||
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||||
export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); }
|
export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); }
|
||||||
|
export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
|
||||||
|
export function invalidImportTableSchemaError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableSchemaError', "Table schema is not supported for model import. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
|
||||||
|
|
||||||
export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
|
export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
|
||||||
export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
|
export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
|
||||||
|
|||||||
@@ -183,10 +183,31 @@ export class QueryRunner {
|
|||||||
try {
|
try {
|
||||||
return await this.runQuery(connection, query);
|
return await this.runQuery(connection, query);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log(error);
|
//console.log(error);
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes the query but doesn't fail it is fails
|
||||||
|
* @param connection SQL connection
|
||||||
|
* @param query query to run
|
||||||
|
*/
|
||||||
|
public async runWithDatabaseChange(connection: azdata.connection.ConnectionProfile, query: string, queryDb: string): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||||
|
if (connection) {
|
||||||
|
try {
|
||||||
|
return await this.runQuery(connection, `
|
||||||
|
USE [${utils.doubleEscapeSingleBrackets(queryDb)}]
|
||||||
|
${query}`);
|
||||||
|
} catch (error) {
|
||||||
|
console.log(error);
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
this.safeRunQuery(connection, `USE [${utils.doubleEscapeSingleBrackets(connection.databaseName || 'master')}]`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import * as fs from 'fs';
|
|||||||
import * as constants from '../common/constants';
|
import * as constants from '../common/constants';
|
||||||
import { promisify } from 'util';
|
import { promisify } from 'util';
|
||||||
import { ApiWrapper } from './apiWrapper';
|
import { ApiWrapper } from './apiWrapper';
|
||||||
import { Config } from '../configurations/config';
|
|
||||||
|
|
||||||
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
|
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
|
||||||
let tempFilePath: string = '';
|
let tempFilePath: string = '';
|
||||||
@@ -221,21 +220,21 @@ export function getScriptWithDBChange(currentDb: string, databaseName: string, s
|
|||||||
* Returns full name of model registration table
|
* Returns full name of model registration table
|
||||||
* @param config config
|
* @param config config
|
||||||
*/
|
*/
|
||||||
export function getRegisteredModelsThreePartsName(config: Config) {
|
export function getRegisteredModelsThreePartsName(db: string, table: string, schema: string) {
|
||||||
const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName);
|
const dbName = doubleEscapeSingleBrackets(db);
|
||||||
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
|
const schemaName = doubleEscapeSingleBrackets(schema);
|
||||||
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
const tableName = doubleEscapeSingleBrackets(table);
|
||||||
return `[${dbName}].[${schema}].[${tableName}]`;
|
return `[${dbName}].[${schemaName}].[${tableName}]`;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns full name of model registration table
|
* Returns full name of model registration table
|
||||||
* @param config config object
|
* @param config config object
|
||||||
*/
|
*/
|
||||||
export function getRegisteredModelsTowPartsName(config: Config) {
|
export function getRegisteredModelsTwoPartsName(table: string, schema: string) {
|
||||||
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
|
const schemaName = doubleEscapeSingleBrackets(schema);
|
||||||
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
const tableName = doubleEscapeSingleBrackets(table);
|
||||||
return `[${schema}].[${tableName}]`;
|
return `[${schemaName}].[${tableName}]`;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import { AzureModelRegistryService } from '../modelManagement/azureModelRegistry
|
|||||||
import { ModelPythonClient } from '../modelManagement/modelPythonClient';
|
import { ModelPythonClient } from '../modelManagement/modelPythonClient';
|
||||||
import { PredictService } from '../prediction/predictService';
|
import { PredictService } from '../prediction/predictService';
|
||||||
import { DashboardWidget } from '../views/widgets/dashboardWidget';
|
import { DashboardWidget } from '../views/widgets/dashboardWidget';
|
||||||
|
import { ModelConfigRecent } from '../modelManagement/modelConfigRecent';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The main controller class that initializes the extension
|
* The main controller class that initializes the extension
|
||||||
@@ -102,12 +103,13 @@ export default class MainController implements vscode.Disposable {
|
|||||||
let languagesModel = new LanguageService(this._apiWrapper, mssqlService);
|
let languagesModel = new LanguageService(this._apiWrapper, mssqlService);
|
||||||
let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel);
|
let languageController = new LanguageController(this._apiWrapper, this._rootPath, languagesModel);
|
||||||
let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
|
let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
|
||||||
|
let modelRecentService = new ModelConfigRecent(this._context.globalState);
|
||||||
|
|
||||||
// Model Management
|
// Model Management
|
||||||
//
|
//
|
||||||
let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
|
let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter, modelRecentService);
|
||||||
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
|
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
|
||||||
let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config);
|
let predictService = new PredictService(this._apiWrapper, this._queryRunner);
|
||||||
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
|
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
|
||||||
azureModelsService, registeredModelService, predictService);
|
azureModelsService, registeredModelService, predictService);
|
||||||
|
|
||||||
@@ -121,7 +123,7 @@ export default class MainController implements vscode.Disposable {
|
|||||||
await modelManagementController.manageRegisteredModels();
|
await modelManagementController.manageRegisteredModels();
|
||||||
}));
|
}));
|
||||||
this._apiWrapper.registerCommand(constants.mlImportModelCommand, (async () => {
|
this._apiWrapper.registerCommand(constants.mlImportModelCommand, (async () => {
|
||||||
await modelManagementController.registerModel();
|
await modelManagementController.registerModel(undefined);
|
||||||
}));
|
}));
|
||||||
this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => {
|
this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => {
|
||||||
await modelManagementController.predictModel();
|
await modelManagementController.predictModel();
|
||||||
@@ -135,15 +137,6 @@ export default class MainController implements vscode.Disposable {
|
|||||||
this._apiWrapper.registerTaskHandler(constants.mlManageLanguagesCommand, async () => {
|
this._apiWrapper.registerTaskHandler(constants.mlManageLanguagesCommand, async () => {
|
||||||
await languageController.manageLanguages();
|
await languageController.manageLanguages();
|
||||||
});
|
});
|
||||||
this._apiWrapper.registerTaskHandler(constants.mlManageModelsCommand, async () => {
|
|
||||||
await modelManagementController.manageRegisteredModels();
|
|
||||||
});
|
|
||||||
this._apiWrapper.registerTaskHandler(constants.mlImportModelCommand, async () => {
|
|
||||||
await modelManagementController.registerModel();
|
|
||||||
});
|
|
||||||
this._apiWrapper.registerTaskHandler(constants.mlsPredictModelCommand, async () => {
|
|
||||||
await modelManagementController.predictModel();
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import { QueryRunner } from '../common/queryRunner';
|
|||||||
import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces';
|
import { RegisteredModel, RegisteredModelDetails, ModelParameters } from './interfaces';
|
||||||
import { ModelPythonClient } from './modelPythonClient';
|
import { ModelPythonClient } from './modelPythonClient';
|
||||||
import * as constants from '../common/constants';
|
import * as constants from '../common/constants';
|
||||||
|
import { DatabaseTable } from '../prediction/interfaces';
|
||||||
|
import { ModelConfigRecent } from './modelConfigRecent';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Service to deployed models
|
* Service to deployed models
|
||||||
@@ -25,23 +27,25 @@ export class DeployedModelService {
|
|||||||
private _apiWrapper: ApiWrapper,
|
private _apiWrapper: ApiWrapper,
|
||||||
private _config: Config,
|
private _config: Config,
|
||||||
private _queryRunner: QueryRunner,
|
private _queryRunner: QueryRunner,
|
||||||
private _modelClient: ModelPythonClient) {
|
private _modelClient: ModelPythonClient,
|
||||||
|
private _recentModelService: ModelConfigRecent) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns deployed models
|
* Returns deployed models
|
||||||
*/
|
*/
|
||||||
public async getDeployedModels(): Promise<RegisteredModel[]> {
|
public async getDeployedModels(table: DatabaseTable): Promise<RegisteredModel[]> {
|
||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
let list: RegisteredModel[] = [];
|
let list: RegisteredModel[] = [];
|
||||||
|
if (!table.databaseName || !table.tableName || !table.schema) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
if (connection) {
|
if (connection) {
|
||||||
let query = this.getConfigureQuery(connection.databaseName);
|
const query = this.getDeployedModelsQuery(table);
|
||||||
await this._queryRunner.safeRunQuery(connection, query);
|
|
||||||
query = this.getDeployedModelsQuery();
|
|
||||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||||
if (result && result.rows && result.rows.length > 0) {
|
if (result && result.rows && result.rows.length > 0) {
|
||||||
result.rows.forEach(row => {
|
result.rows.forEach(row => {
|
||||||
list.push(this.loadModelData(row));
|
list.push(this.loadModelData(row, table));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -82,10 +86,13 @@ export class DeployedModelService {
|
|||||||
* @param filePath model file path
|
* @param filePath model file path
|
||||||
* @param details model details
|
* @param details model details
|
||||||
*/
|
*/
|
||||||
public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
public async deployLocalModel(filePath: string, details: RegisteredModelDetails | undefined, table: DatabaseTable) {
|
||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
if (connection) {
|
if (connection && table.databaseName) {
|
||||||
let currentModels = await this.getDeployedModels();
|
|
||||||
|
await this.configureImport(connection, table);
|
||||||
|
|
||||||
|
let currentModels = await this.getDeployedModels(table);
|
||||||
const content = await utils.readFileInHex(filePath);
|
const content = await utils.readFileInHex(filePath);
|
||||||
const fileName = details?.fileName || utils.getFileName(filePath);
|
const fileName = details?.fileName || utils.getFileName(filePath);
|
||||||
let modelToAdd: RegisteredModel = {
|
let modelToAdd: RegisteredModel = {
|
||||||
@@ -94,25 +101,92 @@ export class DeployedModelService {
|
|||||||
content: content,
|
content: content,
|
||||||
title: details?.title || fileName,
|
title: details?.title || fileName,
|
||||||
description: details?.description,
|
description: details?.description,
|
||||||
version: details?.version
|
version: details?.version,
|
||||||
|
table: table
|
||||||
};
|
};
|
||||||
await this._queryRunner.safeRunQuery(connection, this.getInsertModelQuery(connection.databaseName, modelToAdd));
|
await this._queryRunner.runWithDatabaseChange(connection, this.getInsertModelQuery(modelToAdd, table), table.databaseName);
|
||||||
|
|
||||||
let updatedModels = await this.getDeployedModels();
|
let updatedModels = await this.getDeployedModels(table);
|
||||||
if (updatedModels.length < currentModels.length + 1) {
|
if (updatedModels.length < currentModels.length + 1) {
|
||||||
throw Error(constants.importModelFailedError(details?.title, filePath));
|
throw Error(constants.importModelFailedError(details?.title, filePath));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
throw new Error(constants.noConnectionError);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private loadModelData(row: azdata.DbCellValue[]): RegisteredModel {
|
|
||||||
|
public async configureImport(connection: azdata.connection.ConnectionProfile, table: DatabaseTable) {
|
||||||
|
if (connection && table.databaseName) {
|
||||||
|
let query = this.getDatabaseConfigureQuery(table);
|
||||||
|
await this._queryRunner.safeRunQuery(connection, query);
|
||||||
|
|
||||||
|
query = this.getConfigureTableQuery(table);
|
||||||
|
await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Verifies if the given table name is valid to be used as import table. If table doesn't exist returns true to create new table
|
||||||
|
* Otherwise verifies the schema and returns true if the schema is supported
|
||||||
|
* @param connection database connection
|
||||||
|
* @param table config table name
|
||||||
|
*/
|
||||||
|
public async verifyConfigTable(table: DatabaseTable): Promise<boolean> {
|
||||||
|
let connection = await this.getCurrentConnection();
|
||||||
|
if (connection && table.databaseName) {
|
||||||
|
let databases = await this._apiWrapper.listDatabases(connection.connectionId);
|
||||||
|
|
||||||
|
// If database exist verify the table schema
|
||||||
|
//
|
||||||
|
if ((await databases).find(x => x === table.databaseName)) {
|
||||||
|
const query = this.getConfigTableVerificationQuery(table);
|
||||||
|
const result = await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
|
||||||
|
return result !== undefined && result.rows.length > 0 && result.rows[0][0].displayValue === '1';
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Error(constants.noConnectionError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async getRecentImportTable(): Promise<DatabaseTable> {
|
||||||
|
let connection = await this.getCurrentConnection();
|
||||||
|
let table: DatabaseTable | undefined;
|
||||||
|
if (connection) {
|
||||||
|
table = this._recentModelService.getModelTable(connection);
|
||||||
|
if (!table) {
|
||||||
|
table = {
|
||||||
|
databaseName: connection.databaseName ?? 'master',
|
||||||
|
tableName: this._config.registeredModelTableName,
|
||||||
|
schema: this._config.registeredModelTableSchemaName
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Error(constants.noConnectionError);
|
||||||
|
}
|
||||||
|
return table;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async storeRecentImportTable(importTable: DatabaseTable): Promise<void> {
|
||||||
|
let connection = await this.getCurrentConnection();
|
||||||
|
if (connection) {
|
||||||
|
this._recentModelService.storeModelTable(connection, importTable);
|
||||||
|
} else {
|
||||||
|
throw new Error(constants.noConnectionError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): RegisteredModel {
|
||||||
return {
|
return {
|
||||||
id: +row[0].displayValue,
|
id: +row[0].displayValue,
|
||||||
artifactName: row[1].displayValue,
|
artifactName: row[1].displayValue,
|
||||||
title: row[2].displayValue,
|
title: row[2].displayValue,
|
||||||
description: row[3].displayValue,
|
description: row[3].displayValue,
|
||||||
version: row[4].displayValue,
|
version: row[4].displayValue,
|
||||||
created: row[5].displayValue
|
created: row[5].displayValue,
|
||||||
|
table: table
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,87 +194,138 @@ export class DeployedModelService {
|
|||||||
return await this._apiWrapper.getCurrentConnection();
|
return await this._apiWrapper.getCurrentConnection();
|
||||||
}
|
}
|
||||||
|
|
||||||
public getConfigureQuery(currentDatabaseName: string): string {
|
public getDatabaseConfigureQuery(configTable: DatabaseTable): string {
|
||||||
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.getConfigureTableQuery());
|
return `
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT name
|
||||||
|
FROM sys.databases
|
||||||
|
WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}'
|
||||||
|
)
|
||||||
|
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}]
|
||||||
|
`;
|
||||||
}
|
}
|
||||||
|
|
||||||
public getDeployedModelsQuery(): string {
|
public getDeployedModelsQuery(table: DatabaseTable): string {
|
||||||
return `
|
return `
|
||||||
SELECT artifact_id, artifact_name, name, description, version, created
|
SELECT artifact_id, artifact_name, name, description, version, created
|
||||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')}
|
||||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||||
Order by artifact_id
|
Order by artifact_id
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Verifies config table has the expected schema
|
||||||
|
* @param databaseName
|
||||||
|
* @param tableName
|
||||||
|
*/
|
||||||
|
public getConfigTableVerificationQuery(table: DatabaseTable): string {
|
||||||
|
let tableName = table.tableName;
|
||||||
|
let schemaName = table.schema;
|
||||||
|
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
|
||||||
|
|
||||||
|
return `
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT name
|
||||||
|
FROM sys.databases
|
||||||
|
WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}'
|
||||||
|
)
|
||||||
|
BEGIN
|
||||||
|
Select 1
|
||||||
|
END
|
||||||
|
ELSE
|
||||||
|
BEGIN
|
||||||
|
USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}]
|
||||||
|
IF EXISTS
|
||||||
|
( SELECT t.name, s.name
|
||||||
|
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||||
|
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||||
|
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
||||||
|
)
|
||||||
|
BEGIN
|
||||||
|
IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version')
|
||||||
|
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created')
|
||||||
|
BEGIN
|
||||||
|
Select 1
|
||||||
|
END
|
||||||
|
ELSE
|
||||||
|
BEGIN
|
||||||
|
Select 0
|
||||||
|
END
|
||||||
|
END
|
||||||
|
ELSE
|
||||||
|
select 1
|
||||||
|
END
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the table and adds extra columns (name, description, version) if doesn't already exist.
|
* Update the table and adds extra columns (name, description, version) if doesn't already exist.
|
||||||
* Note: this code is temporary and will be removed weh the table supports the required schema
|
* Note: this code is temporary and will be removed weh the table supports the required schema
|
||||||
* @param databaseName
|
* @param databaseName
|
||||||
* @param tableName
|
* @param tableName
|
||||||
*/
|
*/
|
||||||
public getConfigureTableQuery(): string {
|
public getConfigureTableQuery(table: DatabaseTable): string {
|
||||||
let databaseName = this._config.registeredModelDatabaseName;
|
let tableName = table.tableName;
|
||||||
let tableName = this._config.registeredModelTableName;
|
let schemaName = table.schema;
|
||||||
let schemaName = this._config.registeredModelTableSchemaName;
|
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
|
||||||
|
|
||||||
return `
|
return `
|
||||||
IF NOT EXISTS (
|
|
||||||
SELECT [name]
|
|
||||||
FROM sys.databases
|
|
||||||
WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}'
|
|
||||||
)
|
|
||||||
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
|
||||||
GO
|
|
||||||
USE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
|
||||||
IF EXISTS
|
IF EXISTS
|
||||||
( SELECT [t.name], [s.name]
|
( SELECT t.name, s.name
|
||||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||||
WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||||
AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
||||||
)
|
)
|
||||||
BEGIN
|
BEGIN
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_name')
|
||||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL
|
ALTER TABLE ${twoPartTableName} ADD [artifact_name] [varchar](256) NOT NULL
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='artifact_content')
|
||||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL
|
ALTER TABLE ${twoPartTableName} ADD [artifact_content] [varbinary](max) NOT NULL
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='name')
|
||||||
|
ALTER TABLE ${twoPartTableName} ADD [name] [varchar](256) NULL
|
||||||
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='version')
|
||||||
|
ALTER TABLE ${twoPartTableName} ADD [version] [varchar](256) NULL
|
||||||
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='created')
|
||||||
BEGIN
|
BEGIN
|
||||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL
|
ALTER TABLE ${twoPartTableName} ADD [created] [datetime] NULL
|
||||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
ALTER TABLE ${twoPartTableName} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||||
END
|
END
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='description')
|
||||||
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL
|
ALTER TABLE ${twoPartTableName} ADD [description] [varchar](256) NULL
|
||||||
END
|
END
|
||||||
Else
|
Else
|
||||||
BEGIN
|
BEGIN
|
||||||
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
|
CREATE TABLE ${twoPartTableName}(
|
||||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||||
[artifact_name] [varchar](256) NOT NULL,
|
[artifact_name] [varchar](256) NOT NULL,
|
||||||
[group_path] [varchar](256) NULL,
|
|
||||||
[artifact_content] [varbinary](max) NOT NULL,
|
[artifact_content] [varbinary](max) NOT NULL,
|
||||||
[artifact_initial_size] [bigint] NULL,
|
[artifact_initial_size] [bigint] NULL,
|
||||||
[name] [varchar](256) NULL,
|
[name] [varchar](256) NULL,
|
||||||
[version] [varchar](256) NULL,
|
[version] [varchar](256) NULL,
|
||||||
[created] [datetime] NULL,
|
[created] [datetime] NULL,
|
||||||
[description] [varchar](256) NULL,
|
[description] [varchar](256) NULL,
|
||||||
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_artifact_pk] PRIMARY KEY CLUSTERED
|
||||||
(
|
(
|
||||||
[artifact_id] ASC
|
[artifact_id] ASC
|
||||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||||
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
ALTER TABLE [dbo].[${utils.doubleEscapeSingleBrackets(tableName)}] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||||
END
|
END
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|
||||||
public getInsertModelQuery(currentDatabaseName: string, model: RegisteredModel): string {
|
public getInsertModelQuery(model: RegisteredModel, table: DatabaseTable): string {
|
||||||
|
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
|
||||||
|
const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '');
|
||||||
let updateScript = `
|
let updateScript = `
|
||||||
Insert into ${utils.getRegisteredModelsTowPartsName(this._config)}
|
Insert into ${twoPartTableName}
|
||||||
(artifact_name, group_path, artifact_content, name, version, description)
|
(artifact_name, artifact_content, name, version, description)
|
||||||
values (
|
values (
|
||||||
'${utils.doubleEscapeSingleQuotes(model.artifactName || '')}',
|
'${utils.doubleEscapeSingleQuotes(model.artifactName || '')}',
|
||||||
'ADS',
|
|
||||||
${utils.doubleEscapeSingleQuotes(model.content || '')},
|
${utils.doubleEscapeSingleQuotes(model.content || '')},
|
||||||
'${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
'${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||||
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||||
@@ -208,17 +333,19 @@ export class DeployedModelService {
|
|||||||
`;
|
`;
|
||||||
|
|
||||||
return `
|
return `
|
||||||
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
|
${updateScript}
|
||||||
|
|
||||||
SELECT artifact_id, artifact_name, name, description, version, created
|
SELECT artifact_id, artifact_name, name, description, version, created
|
||||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
FROM ${threePartTableName}
|
||||||
WHERE artifact_id = SCOPE_IDENTITY();
|
WHERE artifact_id = SCOPE_IDENTITY();
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|
||||||
public getModelContentQuery(model: RegisteredModel): string {
|
public getModelContentQuery(model: RegisteredModel): string {
|
||||||
|
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
|
||||||
return `
|
return `
|
||||||
SELECT artifact_content
|
SELECT artifact_content
|
||||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
FROM ${threePartTableName}
|
||||||
WHERE artifact_id = ${model.id};
|
WHERE artifact_id = ${model.id};
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
import * as msRest from '@azure/ms-rest-js';
|
import * as msRest from '@azure/ms-rest-js';
|
||||||
import { Resource } from '@azure/arm-machinelearningservices/esm/models';
|
import { Resource } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
|
import { DatabaseTable } from '../prediction/interfaces';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An interface representing ListWorkspaceModelResult.
|
* An interface representing ListWorkspaceModelResult.
|
||||||
@@ -52,6 +53,7 @@ export interface RegisteredModel extends RegisteredModelDetails {
|
|||||||
id: number;
|
id: number;
|
||||||
artifactName: string;
|
artifactName: string;
|
||||||
content?: string;
|
content?: string;
|
||||||
|
table: DatabaseTable;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ModelParameter {
|
export interface ModelParameter {
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as vscode from 'vscode';
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import { DatabaseTable } from '../prediction/interfaces';
|
||||||
|
|
||||||
|
const TableConfigName = 'MLS_ModelTableConfigName';
|
||||||
|
|
||||||
|
export class ModelConfigRecent {
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
constructor(private _memento: vscode.Memento) {
|
||||||
|
}
|
||||||
|
|
||||||
|
public getModelTable(connection: azdata.connection.ConnectionProfile): DatabaseTable | undefined {
|
||||||
|
return this._memento.get<DatabaseTable>(this.getKey(connection));
|
||||||
|
}
|
||||||
|
|
||||||
|
public storeModelTable(connection: azdata.connection.ConnectionProfile, table: DatabaseTable): void {
|
||||||
|
this._memento.update(this.getKey(connection), table);
|
||||||
|
}
|
||||||
|
|
||||||
|
private getKey(connection: azdata.connection.ConnectionProfile): string {
|
||||||
|
return `${TableConfigName}_${connection.serverName}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,6 @@ import { QueryRunner } from '../common/queryRunner';
|
|||||||
import * as utils from '../common/utils';
|
import * as utils from '../common/utils';
|
||||||
import { RegisteredModel } from '../modelManagement/interfaces';
|
import { RegisteredModel } from '../modelManagement/interfaces';
|
||||||
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
|
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
|
||||||
import { Config } from '../configurations/config';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Service to make prediction
|
* Service to make prediction
|
||||||
@@ -22,8 +21,7 @@ export class PredictService {
|
|||||||
*/
|
*/
|
||||||
constructor(
|
constructor(
|
||||||
private _apiWrapper: ApiWrapper,
|
private _apiWrapper: ApiWrapper,
|
||||||
private _queryRunner: QueryRunner,
|
private _queryRunner: QueryRunner) {
|
||||||
private _config: Config) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -54,7 +52,8 @@ export class PredictService {
|
|||||||
registeredModel.id,
|
registeredModel.id,
|
||||||
predictParams.inputColumns || [],
|
predictParams.inputColumns || [],
|
||||||
predictParams.outputColumns || [],
|
predictParams.outputColumns || [],
|
||||||
predictParams);
|
predictParams,
|
||||||
|
registeredModel.table);
|
||||||
} else if (filePath) {
|
} else if (filePath) {
|
||||||
let modelBytes = await utils.readFileInHex(filePath || '');
|
let modelBytes = await utils.readFileInHex(filePath || '');
|
||||||
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
|
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
|
||||||
@@ -142,18 +141,20 @@ WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuo
|
|||||||
modelId: number,
|
modelId: number,
|
||||||
columns: PredictColumn[],
|
columns: PredictColumn[],
|
||||||
outputColumns: PredictColumn[],
|
outputColumns: PredictColumn[],
|
||||||
databaseNameTable: DatabaseTable): string {
|
sourceTable: DatabaseTable,
|
||||||
|
importTable: DatabaseTable): string {
|
||||||
|
const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || '');
|
||||||
return `
|
return `
|
||||||
DECLARE @model VARBINARY(max) = (
|
DECLARE @model VARBINARY(max) = (
|
||||||
SELECT artifact_content
|
SELECT artifact_content
|
||||||
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
FROM ${threePartTableName}
|
||||||
WHERE artifact_id = ${modelId}
|
WHERE artifact_id = ${modelId}
|
||||||
);
|
);
|
||||||
WITH predict_input
|
WITH predict_input
|
||||||
AS (
|
AS (
|
||||||
SELECT TOP 1000
|
SELECT TOP 1000
|
||||||
${this.getInputColumnNames(columns, 'pi')}
|
${this.getInputColumnNames(columns, 'pi')}
|
||||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
|
||||||
)
|
)
|
||||||
SELECT
|
SELECT
|
||||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
|
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import * as path from 'path';
|
|||||||
import * as os from 'os';
|
import * as os from 'os';
|
||||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||||
import * as fs from 'fs';
|
import * as fs from 'fs';
|
||||||
|
import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent';
|
||||||
|
import { DatabaseTable } from '../../prediction/interfaces';
|
||||||
|
|
||||||
interface TestContext {
|
interface TestContext {
|
||||||
|
|
||||||
@@ -24,6 +26,8 @@ interface TestContext {
|
|||||||
config: TypeMoq.IMock<Config>;
|
config: TypeMoq.IMock<Config>;
|
||||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||||
modelClient: TypeMoq.IMock<ModelPythonClient>;
|
modelClient: TypeMoq.IMock<ModelPythonClient>;
|
||||||
|
recentModels: TypeMoq.IMock<ModelConfigRecent>;
|
||||||
|
importTable: DatabaseTable;
|
||||||
}
|
}
|
||||||
|
|
||||||
function createContext(): TestContext {
|
function createContext(): TestContext {
|
||||||
@@ -32,7 +36,13 @@ function createContext(): TestContext {
|
|||||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||||
config: TypeMoq.Mock.ofType(Config),
|
config: TypeMoq.Mock.ofType(Config),
|
||||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||||
modelClient: TypeMoq.Mock.ofType(ModelPythonClient)
|
modelClient: TypeMoq.Mock.ofType(ModelPythonClient),
|
||||||
|
recentModels: TypeMoq.Mock.ofType(ModelConfigRecent),
|
||||||
|
importTable: {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,14 +50,20 @@ describe('DeployedModelService', () => {
|
|||||||
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
|
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
|
||||||
const testContext = createContext();
|
const testContext = createContext();
|
||||||
let connection: azdata.connection.ConnectionProfile;
|
let connection: azdata.connection.ConnectionProfile;
|
||||||
|
let importTable: DatabaseTable = {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
};
|
||||||
|
|
||||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||||
let service = new DeployedModelService(
|
let service = new DeployedModelService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
await should(service.getDeployedModels()).rejected();
|
testContext.recentModels.object);
|
||||||
|
await should(service.getDeployedModels(importTable)).rejected();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
|
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
|
||||||
@@ -61,7 +77,9 @@ describe('DeployedModelService', () => {
|
|||||||
title: 'title1',
|
title: 'title1',
|
||||||
description: 'desc1',
|
description: 'desc1',
|
||||||
created: '2018-01-01',
|
created: '2018-01-01',
|
||||||
version: '1.1'
|
version: '1.1',
|
||||||
|
table: testContext.importTable
|
||||||
|
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
const result = {
|
const result = {
|
||||||
@@ -106,12 +124,13 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
|
testContext.recentModels.object);
|
||||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||||
|
|
||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||||
const actual = await service.getDeployedModels();
|
const actual = await service.getDeployedModels(testContext.importTable);
|
||||||
should.deepEqual(actual, expected);
|
should.deepEqual(actual, expected);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -140,7 +159,8 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
|
testContext.recentModels.object);
|
||||||
const actual = await service.loadModelParameters('');
|
const actual = await service.loadModelParameters('');
|
||||||
should.deepEqual(actual, expected);
|
should.deepEqual(actual, expected);
|
||||||
});
|
});
|
||||||
@@ -158,7 +178,8 @@ describe('DeployedModelService', () => {
|
|||||||
title: 'title1',
|
title: 'title1',
|
||||||
description: 'desc1',
|
description: 'desc1',
|
||||||
created: '2018-01-01',
|
created: '2018-01-01',
|
||||||
version: '1.1'
|
version: '1.1',
|
||||||
|
table: testContext.importTable
|
||||||
};
|
};
|
||||||
const result = {
|
const result = {
|
||||||
rowCount: 1,
|
rowCount: 1,
|
||||||
@@ -177,7 +198,8 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
|
testContext.recentModels.object);
|
||||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||||
|
|
||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||||
@@ -198,7 +220,8 @@ describe('DeployedModelService', () => {
|
|||||||
title: 'title1',
|
title: 'title1',
|
||||||
description: 'desc1',
|
description: 'desc1',
|
||||||
created: '2018-01-01',
|
created: '2018-01-01',
|
||||||
version: '1.1'
|
version: '1.1',
|
||||||
|
table: testContext.importTable
|
||||||
};
|
};
|
||||||
const row = [
|
const row = [
|
||||||
{
|
{
|
||||||
@@ -247,15 +270,17 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
|
testContext.recentModels.object);
|
||||||
|
|
||||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0))).returns(() => {
|
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('Insert into') > 0), TypeMoq.It.isAny())).returns(() => {
|
||||||
deployed = true;
|
deployed = true;
|
||||||
return Promise.resolve(result);
|
return Promise.resolve(result);
|
||||||
});
|
});
|
||||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||||
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
||||||
});
|
});
|
||||||
|
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||||
|
|
||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||||
@@ -264,7 +289,7 @@ describe('DeployedModelService', () => {
|
|||||||
try {
|
try {
|
||||||
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||||
await fs.promises.writeFile(tempFilePath, 'test');
|
await fs.promises.writeFile(tempFilePath, 'test');
|
||||||
await should(service.deployLocalModel(tempFilePath, model)).resolved();
|
await should(service.deployLocalModel(tempFilePath, model, testContext.importTable)).resolved();
|
||||||
}
|
}
|
||||||
finally {
|
finally {
|
||||||
await utils.deleteFile(tempFilePath);
|
await utils.deleteFile(tempFilePath);
|
||||||
@@ -273,31 +298,28 @@ describe('DeployedModelService', () => {
|
|||||||
|
|
||||||
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
||||||
const testContext = createContext();
|
const testContext = createContext();
|
||||||
const dbName = 'curre[n]tDb';
|
|
||||||
let service = new DeployedModelService(
|
let service = new DeployedModelService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
testContext.recentModels.object);
|
||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
|
||||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
testContext.importTable.databaseName = 'd[]b';
|
||||||
|
testContext.importTable.tableName = 'ta[b]le';
|
||||||
|
testContext.importTable.schema = 'dbo';
|
||||||
const expected = `
|
const expected = `
|
||||||
IF NOT EXISTS (
|
|
||||||
SELECT [name]
|
|
||||||
FROM sys.databases
|
|
||||||
WHERE [name] = N'd[]b'
|
|
||||||
)
|
|
||||||
CREATE DATABASE [d[[]]b]
|
|
||||||
GO
|
|
||||||
USE [d[[]]b]
|
|
||||||
IF EXISTS
|
IF EXISTS
|
||||||
( SELECT [t.name], [s.name]
|
( SELECT t.name, s.name
|
||||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||||
WHERE [t.name] = 'ta[b]le'
|
WHERE t.name = 'ta[b]le'
|
||||||
AND [s.name] = 'dbo'
|
AND s.name = 'dbo'
|
||||||
)
|
)
|
||||||
BEGIN
|
BEGIN
|
||||||
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_name')
|
||||||
|
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_name] [varchar](256) NOT NULL
|
||||||
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='artifact_content')
|
||||||
|
ALTER TABLE [dbo].[ta[[b]]le] ADD [artifact_content] [varbinary](max) NOT NULL
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='name')
|
||||||
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
|
ALTER TABLE [dbo].[ta[[b]]le] ADD [name] [varchar](256) NULL
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[ta[[b]]le]') AND NAME='version')
|
||||||
@@ -315,23 +337,22 @@ describe('DeployedModelService', () => {
|
|||||||
CREATE TABLE [dbo].[ta[[b]]le](
|
CREATE TABLE [dbo].[ta[[b]]le](
|
||||||
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||||
[artifact_name] [varchar](256) NOT NULL,
|
[artifact_name] [varchar](256) NOT NULL,
|
||||||
[group_path] [varchar](256) NULL,
|
|
||||||
[artifact_content] [varbinary](max) NOT NULL,
|
[artifact_content] [varbinary](max) NOT NULL,
|
||||||
[artifact_initial_size] [bigint] NULL,
|
[artifact_initial_size] [bigint] NULL,
|
||||||
[name] [varchar](256) NULL,
|
[name] [varchar](256) NULL,
|
||||||
[version] [varchar](256) NULL,
|
[version] [varchar](256) NULL,
|
||||||
[created] [datetime] NULL,
|
[created] [datetime] NULL,
|
||||||
[description] [varchar](256) NULL,
|
[description] [varchar](256) NULL,
|
||||||
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
CONSTRAINT [ta[[b]]le_artifact_pk] PRIMARY KEY CLUSTERED
|
||||||
(
|
(
|
||||||
[artifact_id] ASC
|
[artifact_id] ASC
|
||||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||||
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||||
END
|
END
|
||||||
`;
|
`;
|
||||||
const actual = service.getConfigureQuery(dbName);
|
const actual = service.getConfigureTableQuery(testContext.importTable);
|
||||||
should.equal(actual.indexOf(expected) > 0, true);
|
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
|
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
|
||||||
@@ -340,23 +361,23 @@ describe('DeployedModelService', () => {
|
|||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
testContext.recentModels.object);
|
||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
testContext.importTable.databaseName = 'd[]b';
|
||||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
testContext.importTable.tableName = 'ta[b]le';
|
||||||
|
testContext.importTable.schema = 'dbo';
|
||||||
const expected = `
|
const expected = `
|
||||||
SELECT artifact_id, artifact_name, name, description, version, created
|
SELECT artifact_id, artifact_name, name, description, version, created
|
||||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||||
Order by artifact_id
|
Order by artifact_id
|
||||||
`;
|
`;
|
||||||
const actual = service.getDeployedModelsQuery();
|
const actual = service.getDeployedModelsQuery(testContext.importTable);
|
||||||
should.deepEqual(expected, actual);
|
should.deepEqual(expected, actual);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
|
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
|
||||||
const testContext = createContext();
|
const testContext = createContext();
|
||||||
const dbName = 'curre[n]tDb';
|
|
||||||
const model: RegisteredModel =
|
const model: RegisteredModel =
|
||||||
{
|
{
|
||||||
id: 1,
|
id: 1,
|
||||||
@@ -364,28 +385,27 @@ describe('DeployedModelService', () => {
|
|||||||
title: 'title1',
|
title: 'title1',
|
||||||
description: 'desc1',
|
description: 'desc1',
|
||||||
created: '2018-01-01',
|
created: '2018-01-01',
|
||||||
version: '1.1'
|
version: '1.1',
|
||||||
|
table: testContext.importTable
|
||||||
};
|
};
|
||||||
|
|
||||||
let service = new DeployedModelService(
|
let service = new DeployedModelService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
testContext.recentModels.object);
|
||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
|
||||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
|
||||||
const expected = `
|
const expected = `
|
||||||
Insert into [dbo].[ta[[b]]le]
|
Insert into [dbo].[tb]
|
||||||
(artifact_name, group_path, artifact_content, name, version, description)
|
(artifact_name, artifact_content, name, version, description)
|
||||||
values (
|
values (
|
||||||
'name1',
|
'name1',
|
||||||
'ADS',
|
|
||||||
,
|
,
|
||||||
'title1',
|
'title1',
|
||||||
'1.1',
|
'1.1',
|
||||||
'desc1')`;
|
'desc1')`;
|
||||||
const actual = service.getInsertModelQuery(dbName, model);
|
const actual = service.getInsertModelQuery(model, testContext.importTable);
|
||||||
should.equal(actual.indexOf(expected) > 0, true);
|
should.equal(actual.indexOf(expected) > 0, true);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -398,17 +418,19 @@ describe('DeployedModelService', () => {
|
|||||||
title: 'title1',
|
title: 'title1',
|
||||||
description: 'desc1',
|
description: 'desc1',
|
||||||
created: '2018-01-01',
|
created: '2018-01-01',
|
||||||
version: '1.1'
|
version: '1.1',
|
||||||
|
table: testContext.importTable
|
||||||
};
|
};
|
||||||
|
|
||||||
let service = new DeployedModelService(
|
let service = new DeployedModelService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.config.object,
|
testContext.config.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object,
|
||||||
testContext.modelClient.object);
|
testContext.modelClient.object,
|
||||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'd[]b');
|
testContext.recentModels.object);
|
||||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'ta[b]le');
|
model.table = {
|
||||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo'
|
||||||
|
};
|
||||||
const expected = `
|
const expected = `
|
||||||
SELECT artifact_content
|
SELECT artifact_content
|
||||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import * as vscode from 'vscode';
|
|||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import * as TypeMoq from 'typemoq';
|
import * as TypeMoq from 'typemoq';
|
||||||
import * as should from 'should';
|
import * as should from 'should';
|
||||||
import { Config } from '../../configurations/config';
|
|
||||||
import { PredictService } from '../../prediction/predictService';
|
import { PredictService } from '../../prediction/predictService';
|
||||||
import { QueryRunner } from '../../common/queryRunner';
|
import { QueryRunner } from '../../common/queryRunner';
|
||||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
import { RegisteredModel } from '../../modelManagement/interfaces';
|
||||||
@@ -22,7 +21,7 @@ import * as fs from 'fs';
|
|||||||
interface TestContext {
|
interface TestContext {
|
||||||
|
|
||||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||||
config: TypeMoq.IMock<Config>;
|
importTable: DatabaseTable;
|
||||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,7 +29,11 @@ function createContext(): TestContext {
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||||
config: TypeMoq.Mock.ofType(Config),
|
importTable: {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
},
|
||||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
|
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -49,8 +52,7 @@ describe('PredictService', () => {
|
|||||||
|
|
||||||
let service = new PredictService(
|
let service = new PredictService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object);
|
||||||
testContext.config.object);
|
|
||||||
const actual = await service.getDatabaseList();
|
const actual = await service.getDatabaseList();
|
||||||
should.deepEqual(actual, expected);
|
should.deepEqual(actual, expected);
|
||||||
});
|
});
|
||||||
@@ -102,8 +104,7 @@ describe('PredictService', () => {
|
|||||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||||
let service = new PredictService(
|
let service = new PredictService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object);
|
||||||
testContext.config.object);
|
|
||||||
const actual = await service.getTableList('db1');
|
const actual = await service.getTableList('db1');
|
||||||
should.deepEqual(actual, expected);
|
should.deepEqual(actual, expected);
|
||||||
});
|
});
|
||||||
@@ -160,8 +161,7 @@ describe('PredictService', () => {
|
|||||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||||
let service = new PredictService(
|
let service = new PredictService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object);
|
||||||
testContext.config.object);
|
|
||||||
const actual = await service.getTableColumnsList(table);
|
const actual = await service.getTableColumnsList(table);
|
||||||
should.deepEqual(actual, expected);
|
should.deepEqual(actual, expected);
|
||||||
});
|
});
|
||||||
@@ -201,13 +201,13 @@ describe('PredictService', () => {
|
|||||||
title: 'title1',
|
title: 'title1',
|
||||||
description: 'desc1',
|
description: 'desc1',
|
||||||
created: '2018-01-01',
|
created: '2018-01-01',
|
||||||
version: '1.1'
|
version: '1.1',
|
||||||
|
table: testContext.importTable
|
||||||
};
|
};
|
||||||
|
|
||||||
let service = new PredictService(
|
let service = new PredictService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object);
|
||||||
testContext.config.object);
|
|
||||||
|
|
||||||
const document: vscode.TextDocument = {
|
const document: vscode.TextDocument = {
|
||||||
uri: vscode.Uri.parse('file:///usr/home'),
|
uri: vscode.Uri.parse('file:///usr/home'),
|
||||||
@@ -270,8 +270,7 @@ describe('PredictService', () => {
|
|||||||
|
|
||||||
let service = new PredictService(
|
let service = new PredictService(
|
||||||
testContext.apiWrapper.object,
|
testContext.apiWrapper.object,
|
||||||
testContext.queryRunner.object,
|
testContext.queryRunner.object);
|
||||||
testContext.config.object);
|
|
||||||
|
|
||||||
const document: vscode.TextDocument = {
|
const document: vscode.TextDocument = {
|
||||||
uri: vscode.Uri.parse('file:///usr/home'),
|
uri: vscode.Uri.parse('file:///usr/home'),
|
||||||
|
|||||||
@@ -34,6 +34,6 @@ describe('Dashboard widget', () => {
|
|||||||
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '');
|
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '');
|
||||||
dashboard.register();
|
dashboard.register();
|
||||||
testContext.onClick.fire();
|
testContext.onClick.fire();
|
||||||
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atMostOnce());
|
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce());
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -0,0 +1,170 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import * as should from 'should';
|
||||||
|
import * as TypeMoq from 'typemoq';
|
||||||
|
import 'mocha';
|
||||||
|
import { createContext } from './utils';
|
||||||
|
import { RegisteredModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||||
|
import { azureResource } from '../../../typings/azure-resource';
|
||||||
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
|
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||||
|
import { ModelManagementController } from '../../../views/models/modelManagementController';
|
||||||
|
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||||
|
|
||||||
|
const accounts: azdata.Account[] = [
|
||||||
|
{
|
||||||
|
key: {
|
||||||
|
accountId: '1',
|
||||||
|
providerId: ''
|
||||||
|
},
|
||||||
|
displayInfo: {
|
||||||
|
displayName: 'account',
|
||||||
|
userId: '',
|
||||||
|
accountType: '',
|
||||||
|
contextualDisplayName: ''
|
||||||
|
},
|
||||||
|
isStale: false,
|
||||||
|
properties: []
|
||||||
|
}
|
||||||
|
];
|
||||||
|
const subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||||
|
{
|
||||||
|
name: 'subscription',
|
||||||
|
id: '2'
|
||||||
|
}
|
||||||
|
];
|
||||||
|
const groups: azureResource.AzureResourceResourceGroup[] = [
|
||||||
|
{
|
||||||
|
name: 'group',
|
||||||
|
id: '3'
|
||||||
|
}
|
||||||
|
];
|
||||||
|
const workspaces: Workspace[] = [
|
||||||
|
{
|
||||||
|
name: 'workspace',
|
||||||
|
id: '4'
|
||||||
|
}
|
||||||
|
];
|
||||||
|
const models: WorkspaceModel[] = [
|
||||||
|
{
|
||||||
|
id: '5',
|
||||||
|
name: 'model'
|
||||||
|
}
|
||||||
|
];
|
||||||
|
const localModels: RegisteredModel[] = [
|
||||||
|
{
|
||||||
|
id: 1,
|
||||||
|
artifactName: 'model',
|
||||||
|
title: 'model',
|
||||||
|
table: {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
const dbNames: string[] = [
|
||||||
|
'db1',
|
||||||
|
'db2'
|
||||||
|
];
|
||||||
|
const tableNames: DatabaseTable[] = [
|
||||||
|
{
|
||||||
|
databaseName: 'db1',
|
||||||
|
schema: 'dbo',
|
||||||
|
tableName: 'tb1'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
databaseName: 'db1',
|
||||||
|
tableName: 'tb2',
|
||||||
|
schema: 'dbo'
|
||||||
|
}
|
||||||
|
];
|
||||||
|
const columnNames: TableColumn[] = [
|
||||||
|
{
|
||||||
|
columnName: 'c1',
|
||||||
|
dataType: 'int'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
columnName: 'c2',
|
||||||
|
dataType: 'varchar'
|
||||||
|
}
|
||||||
|
];
|
||||||
|
const modelParameters: ModelParameters = {
|
||||||
|
inputs: [
|
||||||
|
{
|
||||||
|
'name': 'p1',
|
||||||
|
'type': 'int'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'name': 'p2',
|
||||||
|
'type': 'varchar'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
outputs: [
|
||||||
|
{
|
||||||
|
'name': 'o1',
|
||||||
|
'type': 'int'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
};
|
||||||
|
describe('Model Controller', () => {
|
||||||
|
|
||||||
|
it('Should open deploy model wizard successfully ', async function (): Promise<void> {
|
||||||
|
let testContext = createContext();
|
||||||
|
|
||||||
|
|
||||||
|
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
|
||||||
|
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'table',
|
||||||
|
schema: 'dbo'
|
||||||
|
}));
|
||||||
|
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
|
||||||
|
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames));
|
||||||
|
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames));
|
||||||
|
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
|
||||||
|
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
|
||||||
|
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
|
||||||
|
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
|
||||||
|
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
|
||||||
|
|
||||||
|
const view = await controller.registerModel(undefined);
|
||||||
|
should.notEqual(view, undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('Should open predict wizard successfully ', async function (): Promise<void> {
|
||||||
|
let testContext = createContext();
|
||||||
|
|
||||||
|
|
||||||
|
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
|
||||||
|
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'table',
|
||||||
|
schema: 'dbo'
|
||||||
|
}));
|
||||||
|
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
|
||||||
|
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve([
|
||||||
|
'db', 'db1'
|
||||||
|
]));
|
||||||
|
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve([
|
||||||
|
{ tableName: 'tb', databaseName: 'db', schema: 'dbo' }
|
||||||
|
]));
|
||||||
|
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
|
||||||
|
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
|
||||||
|
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
|
||||||
|
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
|
||||||
|
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
|
||||||
|
testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames));
|
||||||
|
testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters));
|
||||||
|
testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
|
||||||
|
testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
|
||||||
|
|
||||||
|
const view = await controller.predictModel();
|
||||||
|
should.notEqual(view, undefined);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -34,6 +34,11 @@ describe('Predict Wizard', () => {
|
|||||||
let testContext = createContext();
|
let testContext = createContext();
|
||||||
|
|
||||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||||
|
view.importTable = {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
};
|
||||||
await view.open();
|
await view.open();
|
||||||
let accounts: azdata.Account[] = [
|
let accounts: azdata.Account[] = [
|
||||||
{
|
{
|
||||||
@@ -79,7 +84,12 @@ describe('Predict Wizard', () => {
|
|||||||
{
|
{
|
||||||
id: 1,
|
id: 1,
|
||||||
artifactName: 'model',
|
artifactName: 'model',
|
||||||
title: 'model'
|
title: 'model',
|
||||||
|
table: {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
const dbNames: string[] = [
|
const dbNames: string[] = [
|
||||||
|
|||||||
@@ -7,21 +7,25 @@ import * as azdata from 'azdata';
|
|||||||
import * as should from 'should';
|
import * as should from 'should';
|
||||||
import 'mocha';
|
import 'mocha';
|
||||||
import { createContext } from './utils';
|
import { createContext } from './utils';
|
||||||
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType } from '../../../views/models/modelViewBase';
|
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase';
|
||||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
import { azureResource } from '../../../typings/azure-resource';
|
import { azureResource } from '../../../typings/azure-resource';
|
||||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
import { ViewBase } from '../../../views/viewBase';
|
import { ViewBase } from '../../../views/viewBase';
|
||||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||||
import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard';
|
import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard';
|
||||||
|
|
||||||
describe('Register Model Wizard', () => {
|
describe('Register Model Wizard', () => {
|
||||||
it('Should create view components successfully ', async function (): Promise<void> {
|
it('Should create view components successfully ', async function (): Promise<void> {
|
||||||
let testContext = createContext();
|
let testContext = createContext();
|
||||||
|
|
||||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
|
||||||
|
view.importTable = {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'table',
|
||||||
|
schema: 'dbo'
|
||||||
|
};
|
||||||
await view.open();
|
await view.open();
|
||||||
await view.refresh();
|
|
||||||
should.notEqual(view.wizardView, undefined);
|
should.notEqual(view.wizardView, undefined);
|
||||||
should.notEqual(view.modelSourcePage, undefined);
|
should.notEqual(view.modelSourcePage, undefined);
|
||||||
});
|
});
|
||||||
@@ -29,7 +33,12 @@ describe('Register Model Wizard', () => {
|
|||||||
it('Should load data successfully ', async function (): Promise<void> {
|
it('Should load data successfully ', async function (): Promise<void> {
|
||||||
let testContext = createContext();
|
let testContext = createContext();
|
||||||
|
|
||||||
let view = new RegisterModelWizard(testContext.apiWrapper.object, '');
|
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
|
||||||
|
view.importTable = {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
};
|
||||||
await view.open();
|
await view.open();
|
||||||
let accounts: azdata.Account[] = [
|
let accounts: azdata.Account[] = [
|
||||||
{
|
{
|
||||||
@@ -75,12 +84,27 @@ describe('Register Model Wizard', () => {
|
|||||||
{
|
{
|
||||||
id: 1,
|
id: 1,
|
||||||
artifactName: 'model',
|
artifactName: 'model',
|
||||||
title: 'model'
|
title: 'model',
|
||||||
|
table: {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
view.on(ListModelsEventName, () => {
|
view.on(ListModelsEventName, () => {
|
||||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
|
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
|
||||||
});
|
});
|
||||||
|
view.on(ListDatabaseNamesEventName, () => {
|
||||||
|
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [
|
||||||
|
'db', 'db1'
|
||||||
|
] });
|
||||||
|
});
|
||||||
|
view.on(ListTableNamesEventName, () => {
|
||||||
|
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [
|
||||||
|
'tb', 'tb1'
|
||||||
|
] });
|
||||||
|
});
|
||||||
view.on(ListAccountsEventName, () => {
|
view.on(ListAccountsEventName, () => {
|
||||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
import * as should from 'should';
|
import * as should from 'should';
|
||||||
import 'mocha';
|
import 'mocha';
|
||||||
import { createContext } from './utils';
|
import { createContext } from './utils';
|
||||||
import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog';
|
import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog';
|
||||||
import { ListModelsEventName } from '../../../views/models/modelViewBase';
|
import { ListModelsEventName } from '../../../views/models/modelViewBase';
|
||||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
import { ViewBase } from '../../../views/viewBase';
|
import { ViewBase } from '../../../views/viewBase';
|
||||||
@@ -15,7 +15,7 @@ describe('Registered Models Dialog', () => {
|
|||||||
it('Should create view components successfully ', async function (): Promise<void> {
|
it('Should create view components successfully ', async function (): Promise<void> {
|
||||||
let testContext = createContext();
|
let testContext = createContext();
|
||||||
|
|
||||||
let view = new RegisteredModelsDialog(testContext.apiWrapper.object, '');
|
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
|
||||||
view.open();
|
view.open();
|
||||||
|
|
||||||
should.notEqual(view.dialogView, undefined);
|
should.notEqual(view.dialogView, undefined);
|
||||||
@@ -25,13 +25,18 @@ describe('Registered Models Dialog', () => {
|
|||||||
it('Should load data successfully ', async function (): Promise<void> {
|
it('Should load data successfully ', async function (): Promise<void> {
|
||||||
let testContext = createContext();
|
let testContext = createContext();
|
||||||
|
|
||||||
let view = new RegisteredModelsDialog(testContext.apiWrapper.object, '');
|
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
|
||||||
view.open();
|
view.open();
|
||||||
let models: RegisteredModel[] = [
|
let models: RegisteredModel[] = [
|
||||||
{
|
{
|
||||||
id: 1,
|
id: 1,
|
||||||
artifactName: 'model',
|
artifactName: 'model',
|
||||||
title: ''
|
title: '',
|
||||||
|
table: {
|
||||||
|
databaseName: 'db',
|
||||||
|
tableName: 'tb',
|
||||||
|
schema: 'dbo'
|
||||||
|
}
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
view.on(ListModelsEventName, () => {
|
view.on(ListModelsEventName, () => {
|
||||||
|
|||||||
@@ -9,11 +9,17 @@ import * as TypeMoq from 'typemoq';
|
|||||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { createViewContext } from '../utils';
|
import { createViewContext } from '../utils';
|
||||||
import { ModelViewBase } from '../../../views/models/modelViewBase';
|
import { ModelViewBase } from '../../../views/models/modelViewBase';
|
||||||
|
import { AzureModelRegistryService } from '../../../modelManagement/azureModelRegistryService';
|
||||||
|
import { DeployedModelService } from '../../../modelManagement/deployedModelService';
|
||||||
|
import { PredictService } from '../../../prediction/predictService';
|
||||||
|
|
||||||
export interface TestContext {
|
export interface TestContext {
|
||||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||||
view: azdata.ModelView;
|
view: azdata.ModelView;
|
||||||
onClick: vscode.EventEmitter<any>;
|
onClick: vscode.EventEmitter<any>;
|
||||||
|
azureModelService: TypeMoq.IMock<AzureModelRegistryService>;
|
||||||
|
deployModelService: TypeMoq.IMock<DeployedModelService>;
|
||||||
|
predictService: TypeMoq.IMock<PredictService>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export class ParentDialog extends ModelViewBase {
|
export class ParentDialog extends ModelViewBase {
|
||||||
@@ -36,6 +42,9 @@ export function createContext(): TestContext {
|
|||||||
return {
|
return {
|
||||||
apiWrapper: viewTestContext.apiWrapper,
|
apiWrapper: viewTestContext.apiWrapper,
|
||||||
view: viewTestContext.view,
|
view: viewTestContext.view,
|
||||||
onClick: viewTestContext.onClick
|
onClick: viewTestContext.onClick,
|
||||||
|
azureModelService: TypeMoq.Mock.ofType(AzureModelRegistryService),
|
||||||
|
deployModelService: TypeMoq.Mock.ofType(DeployedModelService),
|
||||||
|
predictService: TypeMoq.Mock.ofType(PredictService)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,11 @@ export function createViewContext(): ViewTestContext {
|
|||||||
let button: azdata.ButtonComponent = Object.assign({}, componentBase, {
|
let button: azdata.ButtonComponent = Object.assign({}, componentBase, {
|
||||||
onDidClick: onClick.event
|
onDidClick: onClick.event
|
||||||
});
|
});
|
||||||
|
let link: azdata.HyperlinkComponent = Object.assign({}, componentBase, {
|
||||||
|
onDidClick: onClick.event,
|
||||||
|
label: '',
|
||||||
|
url: ''
|
||||||
|
});
|
||||||
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
|
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
|
||||||
checked: true,
|
checked: true,
|
||||||
onDidClick: onClick.event
|
onDidClick: onClick.event
|
||||||
@@ -61,6 +66,11 @@ export function createViewContext(): ViewTestContext {
|
|||||||
withProperties: () => buttonBuilder,
|
withProperties: () => buttonBuilder,
|
||||||
withValidation: () => buttonBuilder
|
withValidation: () => buttonBuilder
|
||||||
};
|
};
|
||||||
|
let hyperLinkBuilder: azdata.ComponentBuilder<azdata.HyperlinkComponent> = {
|
||||||
|
component: () => link,
|
||||||
|
withProperties: () => hyperLinkBuilder,
|
||||||
|
withValidation: () => hyperLinkBuilder
|
||||||
|
};
|
||||||
let radioButtonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
|
let radioButtonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
|
||||||
component: () => radioButton,
|
component: () => radioButton,
|
||||||
withProperties: () => radioButtonBuilder,
|
withProperties: () => radioButtonBuilder,
|
||||||
@@ -72,7 +82,7 @@ export function createViewContext(): ViewTestContext {
|
|||||||
withValidation: () => checkBoxBuilder
|
withValidation: () => checkBoxBuilder
|
||||||
};
|
};
|
||||||
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
|
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
|
||||||
onTextChanged: undefined!,
|
onTextChanged: onClick.event!,
|
||||||
onEnterKeyPressed: undefined!,
|
onEnterKeyPressed: undefined!,
|
||||||
value: ''
|
value: ''
|
||||||
});
|
});
|
||||||
@@ -216,7 +226,7 @@ export function createViewContext(): ViewTestContext {
|
|||||||
toolbarContainer: undefined!,
|
toolbarContainer: undefined!,
|
||||||
loadingComponent: () => loadingBuilder,
|
loadingComponent: () => loadingBuilder,
|
||||||
fileBrowserTree: undefined!,
|
fileBrowserTree: undefined!,
|
||||||
hyperlink: undefined!,
|
hyperlink: () => hyperLinkBuilder,
|
||||||
tabbedPanel: undefined!,
|
tabbedPanel: undefined!,
|
||||||
separator: undefined!
|
separator: undefined!
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,21 +10,24 @@ import { ModelViewBase } from '../modelViewBase';
|
|||||||
import { CurrentModelsTable } from './currentModelsTable';
|
import { CurrentModelsTable } from './currentModelsTable';
|
||||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { IPageView } from '../../interfaces';
|
import { IPageView } from '../../interfaces';
|
||||||
|
import { TableSelectionComponent } from '../tableSelectionComponent';
|
||||||
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to render current registered models
|
* View to render current registered models
|
||||||
*/
|
*/
|
||||||
export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
||||||
private _tableComponent: azdata.Component | undefined;
|
private _tableComponent: azdata.Component | undefined;
|
||||||
private _dataTable: CurrentModelsTable | undefined;
|
private _dataTable: CurrentModelsTable | undefined;
|
||||||
private _loader: azdata.LoadingComponent | undefined;
|
private _loader: azdata.LoadingComponent | undefined;
|
||||||
|
private _tableSelectionComponent: TableSelectionComponent | undefined;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param apiWrapper Creates new view
|
* @param apiWrapper Creates new view
|
||||||
* @param parent page parent
|
* @param parent page parent
|
||||||
*/
|
*/
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = false) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,11 +36,17 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
|||||||
* @param modelBuilder register the components
|
* @param modelBuilder register the components
|
||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, false);
|
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
|
||||||
|
this._tableSelectionComponent.registerComponent(modelBuilder);
|
||||||
|
this._tableSelectionComponent.onSelectedChanged(async () => {
|
||||||
|
await this.onTableSelected();
|
||||||
|
});
|
||||||
|
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect);
|
||||||
this._dataTable.registerComponent(modelBuilder);
|
this._dataTable.registerComponent(modelBuilder);
|
||||||
this._tableComponent = this._dataTable.component;
|
this._tableComponent = this._dataTable.component;
|
||||||
|
|
||||||
let formModelBuilder = modelBuilder.formContainer();
|
let formModelBuilder = modelBuilder.formContainer();
|
||||||
|
this._tableSelectionComponent.addComponents(formModelBuilder);
|
||||||
|
|
||||||
if (this._tableComponent) {
|
if (this._tableComponent) {
|
||||||
formModelBuilder.addFormItem({
|
formModelBuilder.addFormItem({
|
||||||
@@ -54,6 +63,20 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
|||||||
return this._loader;
|
return this._loader;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._tableSelectionComponent && this._dataTable) {
|
||||||
|
this._tableSelectionComponent.addComponents(formBuilder);
|
||||||
|
this._dataTable.addComponents(formBuilder);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._tableSelectionComponent && this._dataTable) {
|
||||||
|
this._tableSelectionComponent.removeComponents(formBuilder);
|
||||||
|
this._dataTable.removeComponents(formBuilder);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the component
|
* Returns the component
|
||||||
*/
|
*/
|
||||||
@@ -68,6 +91,9 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
|||||||
await this.onLoading();
|
await this.onLoading();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
if (this._tableSelectionComponent) {
|
||||||
|
this._tableSelectionComponent.refresh();
|
||||||
|
}
|
||||||
await this._dataTable?.refresh();
|
await this._dataTable?.refresh();
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
this.showErrorMessage(constants.getErrorMessage(err));
|
this.showErrorMessage(constants.getErrorMessage(err));
|
||||||
@@ -76,6 +102,31 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public get data(): RegisteredModel[] | undefined {
|
||||||
|
return this._dataTable?.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async onTableSelected(): Promise<void> {
|
||||||
|
if (this._tableSelectionComponent?.data) {
|
||||||
|
this.importTable = this._tableSelectionComponent?.data;
|
||||||
|
await this.storeImportConfigTable();
|
||||||
|
await this._dataTable?.refresh();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public get modelTable(): CurrentModelsTable | undefined {
|
||||||
|
return this._dataTable;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* disposes the view
|
||||||
|
*/
|
||||||
|
public async disposeComponent(): Promise<void> {
|
||||||
|
if (this._dataTable) {
|
||||||
|
await this._dataTable.disposeComponent();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns the title of the page
|
* returns the title of the page
|
||||||
*/
|
*/
|
||||||
@@ -134,7 +134,11 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
|
|||||||
if (this._table) {
|
if (this._table) {
|
||||||
let models: RegisteredModel[] | undefined;
|
let models: RegisteredModel[] | undefined;
|
||||||
|
|
||||||
models = await this.listModels();
|
if (this.importTable) {
|
||||||
|
models = await this.listModels(this.importTable);
|
||||||
|
} else {
|
||||||
|
this.showErrorMessage('No import table');
|
||||||
|
}
|
||||||
let tableData: any[][] = [];
|
let tableData: any[][] = [];
|
||||||
|
|
||||||
if (models) {
|
if (models) {
|
||||||
@@ -14,15 +14,17 @@ import { WizardView } from '../../wizardView';
|
|||||||
import { ModelSourcePage } from '../modelSourcePage';
|
import { ModelSourcePage } from '../modelSourcePage';
|
||||||
import { ModelDetailsPage } from '../modelDetailsPage';
|
import { ModelDetailsPage } from '../modelDetailsPage';
|
||||||
import { ModelBrowsePage } from '../modelBrowsePage';
|
import { ModelBrowsePage } from '../modelBrowsePage';
|
||||||
|
import { ModelImportLocationPage } from './modelmportLocationPage';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wizard to register a model
|
* Wizard to register a model
|
||||||
*/
|
*/
|
||||||
export class RegisterModelWizard extends ModelViewBase {
|
export class ImportModelWizard extends ModelViewBase {
|
||||||
|
|
||||||
public modelSourcePage: ModelSourcePage | undefined;
|
public modelSourcePage: ModelSourcePage | undefined;
|
||||||
public modelBrowsePage: ModelBrowsePage | undefined;
|
public modelBrowsePage: ModelBrowsePage | undefined;
|
||||||
public modelDetailsPage: ModelDetailsPage | undefined;
|
public modelDetailsPage: ModelDetailsPage | undefined;
|
||||||
|
public modelImportTargetPage: ModelImportLocationPage | undefined;
|
||||||
public wizardView: WizardView | undefined;
|
public wizardView: WizardView | undefined;
|
||||||
private _parentView: ModelViewBase | undefined;
|
private _parentView: ModelViewBase | undefined;
|
||||||
|
|
||||||
@@ -41,9 +43,10 @@ export class RegisterModelWizard extends ModelViewBase {
|
|||||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
|
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
|
||||||
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
|
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
|
||||||
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this);
|
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this);
|
||||||
|
this.modelImportTargetPage = new ModelImportLocationPage(this._apiWrapper, this);
|
||||||
this.wizardView = new WizardView(this._apiWrapper);
|
this.wizardView = new WizardView(this._apiWrapper);
|
||||||
|
|
||||||
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]);
|
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelImportTargetPage, this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]);
|
||||||
|
|
||||||
this.mainViewPanel = wizard;
|
this.mainViewPanel = wizard;
|
||||||
wizard.doneButton.label = constants.azureRegisterModel;
|
wizard.doneButton.label = constants.azureRegisterModel;
|
||||||
@@ -61,7 +64,8 @@ export class RegisterModelWizard extends ModelViewBase {
|
|||||||
wizard.cancelButton.enabled = true;
|
wizard.cancelButton.enabled = true;
|
||||||
wizard.backButton.enabled = true;
|
wizard.backButton.enabled = true;
|
||||||
if (this._parentView) {
|
if (this._parentView) {
|
||||||
await this._parentView?.refresh();
|
this._parentView.importTable = this.importTable;
|
||||||
|
await this._parentView.refresh();
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
@@ -87,10 +91,11 @@ export class RegisterModelWizard extends ModelViewBase {
|
|||||||
private async registerModel(): Promise<boolean> {
|
private async registerModel(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||||
await this.registerLocalModel(this.modelsViewData);
|
await this.importLocalModel(this.modelsViewData);
|
||||||
} else {
|
} else {
|
||||||
await this.registerAzureModel(this.modelsViewData);
|
await this.importAzureModel(this.modelsViewData);
|
||||||
}
|
}
|
||||||
|
await this.storeImportConfigTable();
|
||||||
this.showInfoMessage(constants.modelRegisteredSuccessfully);
|
this.showInfoMessage(constants.modelRegisteredSuccessfully);
|
||||||
return true;
|
return true;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import { CurrentModelsPage } from './currentModelsPage';
|
import { CurrentModelsComponent } from './currentModelsComponent';
|
||||||
|
|
||||||
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
|
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
|
||||||
import * as constants from '../../../common/constants';
|
import * as constants from '../../../common/constants';
|
||||||
@@ -13,7 +13,7 @@ import { DialogView } from '../../dialogView';
|
|||||||
/**
|
/**
|
||||||
* Dialog to render registered model views
|
* Dialog to render registered model views
|
||||||
*/
|
*/
|
||||||
export class RegisteredModelsDialog extends ModelViewBase {
|
export class ManageModelsDialog extends ModelViewBase {
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
apiWrapper: ApiWrapper,
|
apiWrapper: ApiWrapper,
|
||||||
@@ -22,18 +22,18 @@ export class RegisteredModelsDialog extends ModelViewBase {
|
|||||||
this.dialogView = new DialogView(this._apiWrapper);
|
this.dialogView = new DialogView(this._apiWrapper);
|
||||||
}
|
}
|
||||||
public dialogView: DialogView;
|
public dialogView: DialogView;
|
||||||
public currentLanguagesTab: CurrentModelsPage | undefined;
|
public currentLanguagesTab: CurrentModelsComponent | undefined;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Opens a dialog to manage packages used by notebooks.
|
* Opens a dialog to manage packages used by notebooks.
|
||||||
*/
|
*/
|
||||||
public open(): void {
|
public open(): void {
|
||||||
|
|
||||||
this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this);
|
this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this);
|
||||||
|
|
||||||
let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle);
|
let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle);
|
||||||
registerModelButton.onClick(async () => {
|
registerModelButton.onClick(async () => {
|
||||||
await this.sendDataRequest(RegisterModelEventName);
|
await this.sendDataRequest(RegisterModelEventName, this.currentLanguagesTab?.modelTable?.importTable);
|
||||||
});
|
});
|
||||||
|
|
||||||
let dialog = this.dialogView.createDialog(constants.registerModelTitle, [this.currentLanguagesTab]);
|
let dialog = this.dialogView.createDialog(constants.registerModelTitle, [this.currentLanguagesTab]);
|
||||||
@@ -0,0 +1,98 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
|
import * as constants from '../../../common/constants';
|
||||||
|
import { IPageView, IDataComponent } from '../../interfaces';
|
||||||
|
import { TableSelectionComponent } from '../tableSelectionComponent';
|
||||||
|
import { DatabaseTable } from '../../../prediction/interfaces';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View to pick model source
|
||||||
|
*/
|
||||||
|
export class ModelImportLocationPage extends ModelViewBase implements IPageView, IDataComponent<DatabaseTable> {
|
||||||
|
|
||||||
|
private _form: azdata.FormContainer | undefined;
|
||||||
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
|
public tableSelectionComponent: TableSelectionComponent | undefined;
|
||||||
|
|
||||||
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||||
|
super(apiWrapper, parent.root, parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param modelBuilder Register components
|
||||||
|
*/
|
||||||
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
|
||||||
|
this._formBuilder = modelBuilder.formContainer();
|
||||||
|
this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, true);
|
||||||
|
this.tableSelectionComponent.onSelectedChanged(async () => {
|
||||||
|
await this.onTableSelected();
|
||||||
|
});
|
||||||
|
this.tableSelectionComponent.registerComponent(modelBuilder);
|
||||||
|
this.tableSelectionComponent.addComponents(this._formBuilder);
|
||||||
|
this._form = this._formBuilder.component();
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async onTableSelected(): Promise<void> {
|
||||||
|
if (this.tableSelectionComponent?.data) {
|
||||||
|
this.importTable = this.tableSelectionComponent?.data;
|
||||||
|
//this.sendRequest(StoreImportTableEventName, this.importTable);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): DatabaseTable | undefined {
|
||||||
|
return this.tableSelectionComponent?.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the component
|
||||||
|
*/
|
||||||
|
public get component(): azdata.Component | undefined {
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refreshes the view
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
if (this.tableSelectionComponent) {
|
||||||
|
await this.tableSelectionComponent.refresh();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns page title
|
||||||
|
*/
|
||||||
|
public get title(): string {
|
||||||
|
return constants.modelImportTargetPageTitle;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async disposePage(): Promise<void> {
|
||||||
|
}
|
||||||
|
|
||||||
|
public async validate(): Promise<boolean> {
|
||||||
|
let validated = false;
|
||||||
|
|
||||||
|
if (this.data?.databaseName && this.data?.tableName) {
|
||||||
|
validated = true;
|
||||||
|
validated = await this.verifyImportConfigTable(this.data);
|
||||||
|
if (!validated) {
|
||||||
|
this.showErrorMessage(constants.invalidImportTableSchemaError(this.data?.databaseName, this.data?.tableName));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this.showErrorMessage(constants.invalidImportTableError(this.data?.databaseName, this.data?.tableName));
|
||||||
|
}
|
||||||
|
return validated;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,8 +10,8 @@ import * as constants from '../../common/constants';
|
|||||||
import { IPageView, IDataComponent } from '../interfaces';
|
import { IPageView, IDataComponent } from '../interfaces';
|
||||||
import { LocalModelsComponent } from './localModelsComponent';
|
import { LocalModelsComponent } from './localModelsComponent';
|
||||||
import { AzureModelsComponent } from './azureModelsComponent';
|
import { AzureModelsComponent } from './azureModelsComponent';
|
||||||
import { CurrentModelsTable } from './registerModels/currentModelsTable';
|
|
||||||
import * as utils from '../../common/utils';
|
import * as utils from '../../common/utils';
|
||||||
|
import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to pick model source
|
* View to pick model source
|
||||||
@@ -19,10 +19,11 @@ import * as utils from '../../common/utils';
|
|||||||
export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
|
export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
|
||||||
|
|
||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
|
private _title: string = constants.modelSourcePageTitle;
|
||||||
private _formBuilder: azdata.FormBuilder | undefined;
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
public localModelsComponent: LocalModelsComponent | undefined;
|
public localModelsComponent: LocalModelsComponent | undefined;
|
||||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||||
public registeredModelsComponent: CurrentModelsTable | undefined;
|
public registeredModelsComponent: CurrentModelsComponent | undefined;
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
@@ -39,7 +40,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
this.localModelsComponent.registerComponent(modelBuilder);
|
this.localModelsComponent.registerComponent(modelBuilder);
|
||||||
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect);
|
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect);
|
||||||
this.azureModelsComponent.registerComponent(modelBuilder);
|
this.azureModelsComponent.registerComponent(modelBuilder);
|
||||||
this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this, this._multiSelect);
|
this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, this._multiSelect);
|
||||||
this.registeredModelsComponent.registerComponent(modelBuilder);
|
this.registeredModelsComponent.registerComponent(modelBuilder);
|
||||||
this.refresh();
|
this.refresh();
|
||||||
this._form = this._formBuilder.component();
|
this._form = this._formBuilder.component();
|
||||||
@@ -88,16 +89,29 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
this.registeredModelsComponent.addComponents(this._formBuilder);
|
this.registeredModelsComponent.addComponents(this._formBuilder);
|
||||||
await this.registeredModelsComponent.refresh();
|
await this.registeredModelsComponent.refresh();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
this.loadTitle();
|
||||||
|
}
|
||||||
|
|
||||||
|
private loadTitle(): void {
|
||||||
|
if (this.modelSourceType === ModelSourceType.Local) {
|
||||||
|
this._title = 'Upload model file';
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.Azure) {
|
||||||
|
this._title = 'Import from Azure Machine Learning';
|
||||||
|
|
||||||
|
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
|
||||||
|
this._title = 'Select imported model';
|
||||||
|
} else {
|
||||||
|
this._title = constants.modelSourcePageTitle;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns page title
|
* Returns page title
|
||||||
*/
|
*/
|
||||||
public get title(): string {
|
public get title(): string {
|
||||||
return constants.modelSourcePageTitle;
|
return this._title;
|
||||||
}
|
}
|
||||||
|
|
||||||
public validate(): Promise<boolean> {
|
public validate(): Promise<boolean> {
|
||||||
@@ -117,6 +131,10 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
return Promise.resolve(validated);
|
return Promise.resolve(validated);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public onEnter(): Promise<void> {
|
||||||
|
return Promise.resolve();
|
||||||
|
}
|
||||||
|
|
||||||
public async onLeave(): Promise<void> {
|
public async onLeave(): Promise<void> {
|
||||||
this.modelsViewData = [];
|
this.modelsViewData = [];
|
||||||
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
|
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
|
||||||
@@ -128,7 +146,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
modelDetails: {
|
modelDetails: {
|
||||||
title: fileName,
|
title: fileName,
|
||||||
fileName: fileName
|
fileName: fileName
|
||||||
}
|
},
|
||||||
|
targetImportTable: this.importTable
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -147,7 +166,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
modelDetails: {
|
modelDetails: {
|
||||||
title: x.model?.name || '',
|
title: x.model?.name || '',
|
||||||
fileName: x.model?.name
|
fileName: x.model?.name
|
||||||
}
|
},
|
||||||
|
targetImportTable: this.importTable
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -159,7 +179,8 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
modelData: x,
|
modelData: x,
|
||||||
modelDetails: {
|
modelDetails: {
|
||||||
title: ''
|
title: ''
|
||||||
}
|
},
|
||||||
|
targetImportTable: this.importTable
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,15 +12,15 @@ import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
|||||||
import { RegisteredModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces';
|
import { RegisteredModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces';
|
||||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||||
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
||||||
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
|
import { ManageModelsDialog } from './manageModels/manageModelsDialog';
|
||||||
import {
|
import {
|
||||||
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||||
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
|
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
|
||||||
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData
|
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName
|
||||||
} from './modelViewBase';
|
} from './modelViewBase';
|
||||||
import { ControllerBase } from '../controllerBase';
|
import { ControllerBase } from '../controllerBase';
|
||||||
import { RegisterModelWizard } from './registerModels/registerModelWizard';
|
import { ImportModelWizard } from './manageModels/importModelWizard';
|
||||||
import * as fs from 'fs';
|
import * as fs from 'fs';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
import { PredictWizard } from './prediction/predictWizard';
|
import { PredictWizard } from './prediction/predictWizard';
|
||||||
@@ -51,11 +51,16 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
* @param apiWrapper apiWrapper
|
* @param apiWrapper apiWrapper
|
||||||
* @param root root folder path
|
* @param root root folder path
|
||||||
*/
|
*/
|
||||||
public async registerModel(parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
|
public async registerModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
|
||||||
controller = controller || this;
|
controller = controller || this;
|
||||||
apiWrapper = apiWrapper || this._apiWrapper;
|
apiWrapper = apiWrapper || this._apiWrapper;
|
||||||
root = root || this._root;
|
root = root || this._root;
|
||||||
let view = new RegisterModelWizard(apiWrapper, root, parent);
|
let view = new ImportModelWizard(apiWrapper, root, parent);
|
||||||
|
if (importTable) {
|
||||||
|
view.importTable = importTable;
|
||||||
|
} else {
|
||||||
|
view.importTable = await controller._registeredModelService.getRecentImportTable();
|
||||||
|
}
|
||||||
|
|
||||||
controller.registerEvents(view);
|
controller.registerEvents(view);
|
||||||
|
|
||||||
@@ -72,6 +77,7 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
public async predictModel(): Promise<ModelViewBase> {
|
public async predictModel(): Promise<ModelViewBase> {
|
||||||
|
|
||||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||||
|
view.importTable = await this._registeredModelService.getRecentImportTable();
|
||||||
|
|
||||||
this.registerEvents(view);
|
this.registerEvents(view);
|
||||||
view.on(LoadModelParametersEventName, async () => {
|
view.on(LoadModelParametersEventName, async () => {
|
||||||
@@ -117,17 +123,18 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService
|
await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService
|
||||||
, azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace);
|
, azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace);
|
||||||
});
|
});
|
||||||
|
view.on(ListModelsEventName, async (args) => {
|
||||||
view.on(ListModelsEventName, async () => {
|
const table = <DatabaseTable>args;
|
||||||
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService);
|
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService, table);
|
||||||
});
|
});
|
||||||
view.on(RegisterLocalModelEventName, async (arg) => {
|
view.on(RegisterLocalModelEventName, async (arg) => {
|
||||||
let models = <ModelViewData[]>arg;
|
let models = <ModelViewData[]>arg;
|
||||||
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models);
|
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models);
|
||||||
view.refresh();
|
view.refresh();
|
||||||
});
|
});
|
||||||
view.on(RegisterModelEventName, async () => {
|
view.on(RegisterModelEventName, async (args) => {
|
||||||
await this.executeAction(view, RegisterModelEventName, this.registerModel, view, this, this._apiWrapper, this._root);
|
const importTable = <DatabaseTable>args;
|
||||||
|
await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root);
|
||||||
});
|
});
|
||||||
view.on(RegisterAzureModelEventName, async (arg) => {
|
view.on(RegisterAzureModelEventName, async (arg) => {
|
||||||
let models = <ModelViewData[]>arg;
|
let models = <ModelViewData[]>arg;
|
||||||
@@ -161,6 +168,16 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
|
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
|
||||||
model);
|
model);
|
||||||
});
|
});
|
||||||
|
view.on(StoreImportTableEventName, async (arg) => {
|
||||||
|
let importTable = <DatabaseTable>arg;
|
||||||
|
await this.executeAction(view, StoreImportTableEventName, this.storeImportTable, this._registeredModelService,
|
||||||
|
importTable);
|
||||||
|
});
|
||||||
|
view.on(VerifyImportTableEventName, async (arg) => {
|
||||||
|
let importTable = <DatabaseTable>arg;
|
||||||
|
await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService,
|
||||||
|
importTable);
|
||||||
|
});
|
||||||
view.on(SourceModelSelectedEventName, (arg) => {
|
view.on(SourceModelSelectedEventName, (arg) => {
|
||||||
view.modelSourceType = <ModelSourceType>arg;
|
view.modelSourceType = <ModelSourceType>arg;
|
||||||
view.refresh();
|
view.refresh();
|
||||||
@@ -170,8 +187,14 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
/**
|
/**
|
||||||
* Opens the dialog for model management
|
* Opens the dialog for model management
|
||||||
*/
|
*/
|
||||||
public async manageRegisteredModels(): Promise<ModelViewBase> {
|
public async manageRegisteredModels(importTable?: DatabaseTable): Promise<ModelViewBase> {
|
||||||
let view = new RegisteredModelsDialog(this._apiWrapper, this._root);
|
let view = new ManageModelsDialog(this._apiWrapper, this._root);
|
||||||
|
|
||||||
|
if (importTable) {
|
||||||
|
view.importTable = importTable;
|
||||||
|
} else {
|
||||||
|
view.importTable = await this._registeredModelService.getRecentImportTable();
|
||||||
|
}
|
||||||
|
|
||||||
// Register events
|
// Register events
|
||||||
//
|
//
|
||||||
@@ -202,8 +225,8 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
return await service.getWorkspaces(account, subscription, group);
|
return await service.getWorkspaces(account, subscription, group);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getRegisteredModels(registeredModelService: DeployedModelService): Promise<RegisteredModel[]> {
|
private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise<RegisteredModel[]> {
|
||||||
return registeredModelService.getDeployedModels();
|
return registeredModelService.getDeployedModels(table);
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getAzureModels(
|
private async getAzureModels(
|
||||||
@@ -221,9 +244,13 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise<void> {
|
private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise<void> {
|
||||||
if (models) {
|
if (models) {
|
||||||
await Promise.all(models.map(async (model) => {
|
await Promise.all(models.map(async (model) => {
|
||||||
const localModel = <string>model.modelData;
|
if (model && model.targetImportTable) {
|
||||||
if (localModel) {
|
const localModel = <string>model.modelData;
|
||||||
await service.deployLocalModel(localModel, model.modelDetails);
|
if (localModel) {
|
||||||
|
await service.deployLocalModel(localModel, model.modelDetails, model.targetImportTable);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw Error(constants.invalidModelToRegisterError);
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
} else {
|
} else {
|
||||||
@@ -240,35 +267,39 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
await Promise.all(models.map(async (model) => {
|
await Promise.all(models.map(async (model) => {
|
||||||
const azureModel = <AzureModelResource>model.modelData;
|
if (model && model.targetImportTable) {
|
||||||
if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) {
|
const azureModel = <AzureModelResource>model.modelData;
|
||||||
let filePath: string | undefined;
|
if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) {
|
||||||
try {
|
let filePath: string | undefined;
|
||||||
const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group,
|
try {
|
||||||
azureModel.workspace, azureModel.model);
|
const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group,
|
||||||
if (filePath) {
|
azureModel.workspace, azureModel.model);
|
||||||
await service.deployLocalModel(filePath, model.modelDetails);
|
if (filePath) {
|
||||||
} else {
|
await service.deployLocalModel(filePath, model.modelDetails, model.targetImportTable);
|
||||||
throw Error(constants.invalidModelToRegisterError);
|
} else {
|
||||||
}
|
throw Error(constants.invalidModelToRegisterError);
|
||||||
} finally {
|
}
|
||||||
if (filePath) {
|
} finally {
|
||||||
await fs.promises.unlink(filePath);
|
if (filePath) {
|
||||||
|
await fs.promises.unlink(filePath);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
throw Error(constants.invalidModelToRegisterError);
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
private async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
||||||
return await predictService.getDatabaseList();
|
return await predictService.getDatabaseList();
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
|
private async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
|
||||||
return await predictService.getTableList(databaseName);
|
return await predictService.getTableList(databaseName);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
|
private async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
|
||||||
return await predictService.getTableColumnsList(databaseTable);
|
return await predictService.getTableColumnsList(databaseTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,6 +316,22 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async storeImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<void> {
|
||||||
|
if (table) {
|
||||||
|
await registeredModelService.storeRecentImportTable(table);
|
||||||
|
} else {
|
||||||
|
throw Error(constants.invalidImportTableError(undefined, undefined));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async verifyImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<boolean> {
|
||||||
|
if (table) {
|
||||||
|
return await registeredModelService.verifyConfigTable(table);
|
||||||
|
} else {
|
||||||
|
throw Error(constants.invalidImportTableError(undefined, undefined));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private async downloadRegisteredModel(
|
private async downloadRegisteredModel(
|
||||||
registeredModelService: DeployedModelService,
|
registeredModelService: DeployedModelService,
|
||||||
model: RegisteredModel | undefined): Promise<string> {
|
model: RegisteredModel | undefined): Promise<string> {
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ export interface ModelViewData {
|
|||||||
modelFile?: string;
|
modelFile?: string;
|
||||||
modelData: AzureModelResource | string | RegisteredModel;
|
modelData: AzureModelResource | string | RegisteredModel;
|
||||||
modelDetails?: RegisteredModelDetails;
|
modelDetails?: RegisteredModelDetails;
|
||||||
|
targetImportTable?: DatabaseTable;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Event names
|
// Event names
|
||||||
@@ -58,6 +59,8 @@ export const PredictModelEventName = 'predictModel';
|
|||||||
export const RegisterModelEventName = 'registerModel';
|
export const RegisterModelEventName = 'registerModel';
|
||||||
export const SourceModelSelectedEventName = 'sourceModelSelected';
|
export const SourceModelSelectedEventName = 'sourceModelSelected';
|
||||||
export const LoadModelParametersEventName = 'loadModelParameters';
|
export const LoadModelParametersEventName = 'loadModelParameters';
|
||||||
|
export const StoreImportTableEventName = 'storeImportTable';
|
||||||
|
export const VerifyImportTableEventName = 'verifyImportTable';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class for all model management views
|
* Base class for all model management views
|
||||||
@@ -66,6 +69,7 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
|
|
||||||
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
|
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
|
||||||
private _modelsViewData: ModelViewData[] = [];
|
private _modelsViewData: ModelViewData[] = [];
|
||||||
|
private _importTable: DatabaseTable | undefined;
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
|
||||||
super(apiWrapper, root, parent);
|
super(apiWrapper, root, parent);
|
||||||
@@ -88,7 +92,9 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
PredictModelEventName,
|
PredictModelEventName,
|
||||||
DownloadAzureModelEventName,
|
DownloadAzureModelEventName,
|
||||||
DownloadRegisteredModelEventName,
|
DownloadRegisteredModelEventName,
|
||||||
LoadModelParametersEventName]);
|
LoadModelParametersEventName,
|
||||||
|
StoreImportTableEventName,
|
||||||
|
VerifyImportTableEventName]);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -109,8 +115,8 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
/**
|
/**
|
||||||
* list registered models
|
* list registered models
|
||||||
*/
|
*/
|
||||||
public async listModels(): Promise<RegisteredModel[]> {
|
public async listModels(table: DatabaseTable): Promise<RegisteredModel[]> {
|
||||||
return await this.sendDataRequest(ListModelsEventName);
|
return await this.sendDataRequest(ListModelsEventName, table);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -156,7 +162,7 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
* registers local model
|
* registers local model
|
||||||
* @param localFilePath local file path
|
* @param localFilePath local file path
|
||||||
*/
|
*/
|
||||||
public async registerLocalModel(models: ModelViewData[]): Promise<void> {
|
public async importLocalModel(models: ModelViewData[]): Promise<void> {
|
||||||
return await this.sendDataRequest(RegisterLocalModelEventName, models);
|
return await this.sendDataRequest(RegisterLocalModelEventName, models);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,10 +193,24 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
* registers azure model
|
* registers azure model
|
||||||
* @param args azure resource
|
* @param args azure resource
|
||||||
*/
|
*/
|
||||||
public async registerAzureModel(models: ModelViewData[]): Promise<void> {
|
public async importAzureModel(models: ModelViewData[]): Promise<void> {
|
||||||
return await this.sendDataRequest(RegisterAzureModelEventName, models);
|
return await this.sendDataRequest(RegisterAzureModelEventName, models);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stores the name of the table as recent config table for importing models
|
||||||
|
*/
|
||||||
|
public async storeImportConfigTable(): Promise<void> {
|
||||||
|
await this.sendRequest(StoreImportTableEventName, this.importTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Verifies if table is valid to import models to
|
||||||
|
*/
|
||||||
|
public async verifyImportConfigTable(table: DatabaseTable): Promise<boolean> {
|
||||||
|
return await this.sendDataRequest(VerifyImportTableEventName, table);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* registers azure model
|
* registers azure model
|
||||||
* @param args azure resource
|
* @param args azure resource
|
||||||
@@ -240,7 +260,7 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets model source type
|
* Sets model data
|
||||||
*/
|
*/
|
||||||
public set modelsViewData(value: ModelViewData[]) {
|
public set modelsViewData(value: ModelViewData[]) {
|
||||||
if (this.parent) {
|
if (this.parent) {
|
||||||
@@ -251,7 +271,7 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns model source type
|
* Returns model data
|
||||||
*/
|
*/
|
||||||
public get modelsViewData(): ModelViewData[] {
|
public get modelsViewData(): ModelViewData[] {
|
||||||
if (this.parent) {
|
if (this.parent) {
|
||||||
@@ -261,6 +281,28 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets import table
|
||||||
|
*/
|
||||||
|
public set importTable(value: DatabaseTable | undefined) {
|
||||||
|
if (this.parent) {
|
||||||
|
this.parent.importTable = value;
|
||||||
|
} else {
|
||||||
|
this._importTable = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns import table
|
||||||
|
*/
|
||||||
|
public get importTable(): DatabaseTable | undefined {
|
||||||
|
if (this.parent) {
|
||||||
|
return this.parent.importTable;
|
||||||
|
} else {
|
||||||
|
return this._importTable;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* lists azure workspaces
|
* lists azure workspaces
|
||||||
* @param account azure account
|
* @param account azure account
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import { IDataComponent } from '../../interfaces';
|
|||||||
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
|
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
|
||||||
import { ModelParameters } from '../../../modelManagement/interfaces';
|
import { ModelParameters } from '../../../modelManagement/interfaces';
|
||||||
import { ColumnsTable } from './columnsTable';
|
import { ColumnsTable } from './columnsTable';
|
||||||
|
import { TableSelectionComponent } from '../tableSelectionComponent';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to render filters to pick an azure resource
|
* View to render filters to pick an azure resource
|
||||||
@@ -18,14 +19,10 @@ import { ColumnsTable } from './columnsTable';
|
|||||||
export class InputColumnsComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
|
export class InputColumnsComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
|
||||||
|
|
||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
private _databases: azdata.DropDownComponent | undefined;
|
private _tableSelectionComponent: TableSelectionComponent | undefined;
|
||||||
private _tables: azdata.DropDownComponent | undefined;
|
|
||||||
private _columns: ColumnsTable | undefined;
|
private _columns: ColumnsTable | undefined;
|
||||||
private _dbNames: string[] = [];
|
|
||||||
private _tableNames: DatabaseTable[] = [];
|
|
||||||
private _modelParameters: ModelParameters | undefined;
|
private _modelParameters: ModelParameters | undefined;
|
||||||
private _dbTableComponent: azdata.FlexContainer | undefined;
|
|
||||||
private tableMaxLength = this.componentMaxLength * 2 + 70;
|
|
||||||
/**
|
/**
|
||||||
* Creates a new view
|
* Creates a new view
|
||||||
*/
|
*/
|
||||||
@@ -38,53 +35,15 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
|
|||||||
* @param modelBuilder model builder
|
* @param modelBuilder model builder
|
||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
this._databases = modelBuilder.dropDown().withProperties({
|
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
|
||||||
width: this.componentMaxLength
|
this._tableSelectionComponent.registerComponent(modelBuilder);
|
||||||
}).component();
|
this._tableSelectionComponent.onSelectedChanged(async () => {
|
||||||
this._tables = modelBuilder.dropDown().withProperties({
|
|
||||||
width: this.componentMaxLength
|
|
||||||
}).component();
|
|
||||||
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
|
|
||||||
|
|
||||||
this._databases.onValueChanged(async () => {
|
|
||||||
await this.onDatabaseSelected();
|
|
||||||
});
|
|
||||||
|
|
||||||
this._tables.onValueChanged(async () => {
|
|
||||||
await this.onTableSelected();
|
await this.onTableSelected();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
|
||||||
const databaseForm = modelBuilder.formContainer().withFormItems([{
|
|
||||||
title: constants.columnDatabase,
|
|
||||||
component: this._databases
|
|
||||||
}]).withLayout({
|
|
||||||
padding: '0px'
|
|
||||||
}).component();
|
|
||||||
const tableForm = modelBuilder.formContainer().withFormItems([{
|
|
||||||
title: constants.columnTable,
|
|
||||||
component: this._tables
|
|
||||||
}]).withLayout({
|
|
||||||
padding: '0px'
|
|
||||||
}).component();
|
|
||||||
this._dbTableComponent = modelBuilder.flexContainer().withItems([
|
|
||||||
databaseForm,
|
|
||||||
tableForm
|
|
||||||
], {
|
|
||||||
flex: '0 0 auto',
|
|
||||||
CSSStyles: {
|
|
||||||
'align-items': 'flex-start'
|
|
||||||
}
|
|
||||||
}).withLayout({
|
|
||||||
flexFlow: 'row',
|
|
||||||
justifyContent: 'space-between',
|
|
||||||
width: this.tableMaxLength
|
|
||||||
}).component();
|
|
||||||
|
|
||||||
this._form = modelBuilder.formContainer().withFormItems([{
|
this._form = modelBuilder.formContainer().withFormItems([{
|
||||||
title: '',
|
|
||||||
component: this._dbTableComponent
|
|
||||||
}, {
|
|
||||||
title: constants.inputColumns,
|
title: constants.inputColumns,
|
||||||
component: this._columns.component
|
component: this._columns.component
|
||||||
}]).component();
|
}]).component();
|
||||||
@@ -92,10 +51,10 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
|
|||||||
}
|
}
|
||||||
|
|
||||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._columns && this._dbTableComponent) {
|
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
|
||||||
formBuilder.addFormItems([{
|
formBuilder.addFormItems([{
|
||||||
title: '',
|
title: '',
|
||||||
component: this._dbTableComponent
|
component: this._tableSelectionComponent.component
|
||||||
}, {
|
}, {
|
||||||
title: constants.inputColumns,
|
title: constants.inputColumns,
|
||||||
component: this._columns.component
|
component: this._columns.component
|
||||||
@@ -104,10 +63,10 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
|
|||||||
}
|
}
|
||||||
|
|
||||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
if (this._columns && this._dbTableComponent) {
|
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
|
||||||
formBuilder.removeFormItem({
|
formBuilder.removeFormItem({
|
||||||
title: '',
|
title: '',
|
||||||
component: this._dbTableComponent
|
component: this._tableSelectionComponent.component
|
||||||
});
|
});
|
||||||
formBuilder.removeFormItem({
|
formBuilder.removeFormItem({
|
||||||
title: constants.inputColumns,
|
title: constants.inputColumns,
|
||||||
@@ -136,12 +95,9 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
|
|||||||
* loads data in the components
|
* loads data in the components
|
||||||
*/
|
*/
|
||||||
public async loadData(): Promise<void> {
|
public async loadData(): Promise<void> {
|
||||||
this._dbNames = await this.listDatabaseNames();
|
if (this._tableSelectionComponent) {
|
||||||
if (this._databases && this._dbNames && this._dbNames.length > 0) {
|
this._tableSelectionComponent.refresh();
|
||||||
this._databases.values = this._dbNames;
|
|
||||||
this._databases.value = this._dbNames[0];
|
|
||||||
}
|
}
|
||||||
await this.onDatabaseSelected();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public set modelParameters(value: ModelParameters) {
|
public set modelParameters(value: ModelParameters) {
|
||||||
@@ -167,31 +123,14 @@ export class InputColumnsComponent extends ModelViewBase implements IDataCompone
|
|||||||
await this.loadData();
|
await this.loadData();
|
||||||
}
|
}
|
||||||
|
|
||||||
private async onDatabaseSelected(): Promise<void> {
|
|
||||||
this._tableNames = await this.listTableNames(this.databaseName || '');
|
|
||||||
if (this._tables && this._tableNames && this._tableNames.length > 0) {
|
|
||||||
this._tables.values = this._tableNames.map(t => this.getTableFullName(t));
|
|
||||||
this._tables.value = this.getTableFullName(this._tableNames[0]);
|
|
||||||
}
|
|
||||||
await this.onTableSelected();
|
|
||||||
}
|
|
||||||
|
|
||||||
private getTableFullName(table: DatabaseTable): string {
|
|
||||||
return `${table.schema}.${table.tableName}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
private async onTableSelected(): Promise<void> {
|
private async onTableSelected(): Promise<void> {
|
||||||
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
|
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
private get databaseName(): string | undefined {
|
|
||||||
return <string>this._databases?.value;
|
|
||||||
}
|
|
||||||
|
|
||||||
private get databaseTable(): DatabaseTable {
|
private get databaseTable(): DatabaseTable {
|
||||||
let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._tables?.value);
|
let selectedItem = this._tableSelectionComponent?.data;
|
||||||
return {
|
return {
|
||||||
databaseName: this.databaseName,
|
databaseName: selectedItem?.databaseName,
|
||||||
tableName: selectedItem?.tableName,
|
tableName: selectedItem?.tableName,
|
||||||
schema: selectedItem?.schema
|
schema: selectedItem?.schema
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ export class PredictWizard extends ModelViewBase {
|
|||||||
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
||||||
return await this.azureModelsComponent.getDownloadedModel();
|
return await this.azureModelsComponent.getDownloadedModel();
|
||||||
} else if (this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) {
|
} else if (this.modelBrowsePage && this.modelBrowsePage.registeredModelsComponent) {
|
||||||
return await this.modelBrowsePage.registeredModelsComponent.getDownloadedModel();
|
return await this.modelBrowsePage.registeredModelsComponent.modelTable?.getDownloadedModel();
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,213 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import * as vscode from 'vscode';
|
||||||
|
import { ModelViewBase } from './modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
|
import * as constants from '../../common/constants';
|
||||||
|
import { IDataComponent } from '../interfaces';
|
||||||
|
import { DatabaseTable } from '../../prediction/interfaces';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View to render filters to pick an azure resource
|
||||||
|
*/
|
||||||
|
export class TableSelectionComponent extends ModelViewBase implements IDataComponent<DatabaseTable> {
|
||||||
|
|
||||||
|
private _form: azdata.FormContainer | undefined;
|
||||||
|
private _databases: azdata.DropDownComponent | undefined;
|
||||||
|
private _selectedTableName: string = '';
|
||||||
|
private _tables: azdata.DropDownComponent | undefined;
|
||||||
|
private _dbNames: string[] = [];
|
||||||
|
private _tableNames: DatabaseTable[] = [];
|
||||||
|
private _dbTableComponent: azdata.FlexContainer | undefined;
|
||||||
|
private tableMaxLength = this.componentMaxLength * 2 + 70;
|
||||||
|
private _onSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||||
|
public readonly onSelectedChanged: vscode.Event<void> = this._onSelectedChanged.event;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new view
|
||||||
|
*/
|
||||||
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _editable: boolean) {
|
||||||
|
super(apiWrapper, parent.root, parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register components
|
||||||
|
* @param modelBuilder model builder
|
||||||
|
*/
|
||||||
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
this._databases = modelBuilder.dropDown().withProperties({
|
||||||
|
width: this.componentMaxLength,
|
||||||
|
editable: this._editable,
|
||||||
|
fireOnTextChange: this._editable
|
||||||
|
}).component();
|
||||||
|
this._tables = modelBuilder.dropDown().withProperties({
|
||||||
|
width: this.componentMaxLength,
|
||||||
|
editable: this._editable,
|
||||||
|
fireOnTextChange: this._editable
|
||||||
|
}).component();
|
||||||
|
|
||||||
|
this._databases.onValueChanged(async () => {
|
||||||
|
await this.onDatabaseSelected();
|
||||||
|
});
|
||||||
|
|
||||||
|
this._tables.onValueChanged(async (value) => {
|
||||||
|
// There's an issue with dropdown doesn't set the value in editable mode. this is the workaround
|
||||||
|
|
||||||
|
if (this._tables && value) {
|
||||||
|
this._selectedTableName = this._editable ? value : value.selected;
|
||||||
|
}
|
||||||
|
await this.onTableSelected();
|
||||||
|
});
|
||||||
|
|
||||||
|
const databaseForm = modelBuilder.formContainer().withFormItems([{
|
||||||
|
title: constants.columnDatabase,
|
||||||
|
component: this._databases,
|
||||||
|
}]).withLayout({
|
||||||
|
padding: '0px'
|
||||||
|
}).component();
|
||||||
|
const tableForm = modelBuilder.formContainer().withFormItems([{
|
||||||
|
title: constants.columnTable,
|
||||||
|
component: this._tables
|
||||||
|
}]).withLayout({
|
||||||
|
padding: '0px'
|
||||||
|
}).component();
|
||||||
|
this._dbTableComponent = modelBuilder.flexContainer().withItems([
|
||||||
|
databaseForm,
|
||||||
|
tableForm
|
||||||
|
], {
|
||||||
|
flex: '0 0 auto',
|
||||||
|
CSSStyles: {
|
||||||
|
'align-items': 'flex-start'
|
||||||
|
}
|
||||||
|
}).withLayout({
|
||||||
|
flexFlow: 'row',
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
width: this.tableMaxLength
|
||||||
|
}).component();
|
||||||
|
|
||||||
|
this._form = modelBuilder.formContainer().withFormItems([{
|
||||||
|
title: '',
|
||||||
|
component: this._dbTableComponent
|
||||||
|
}]).component();
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._databases && this._tables) {
|
||||||
|
formBuilder.addFormItems([{
|
||||||
|
title: constants.databaseName,
|
||||||
|
component: this._databases
|
||||||
|
}, {
|
||||||
|
title: constants.tableName,
|
||||||
|
component: this._tables
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._databases && this._tables) {
|
||||||
|
formBuilder.removeFormItem({
|
||||||
|
title: constants.databaseName,
|
||||||
|
component: this._databases
|
||||||
|
});
|
||||||
|
formBuilder.removeFormItem({
|
||||||
|
title: constants.tableName,
|
||||||
|
component: this._tables
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the created component
|
||||||
|
*/
|
||||||
|
public get component(): azdata.Component | undefined {
|
||||||
|
return this._dbTableComponent;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): DatabaseTable | undefined {
|
||||||
|
return this.databaseTable;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* loads data in the components
|
||||||
|
*/
|
||||||
|
public async loadData(): Promise<void> {
|
||||||
|
this._dbNames = await this.listDatabaseNames();
|
||||||
|
if (this._databases && this._dbNames && this._dbNames.length > 0) {
|
||||||
|
this._databases.values = this._dbNames;
|
||||||
|
if (this.importTable) {
|
||||||
|
this._databases.value = this.importTable.databaseName;
|
||||||
|
} else {
|
||||||
|
this._databases.value = this._dbNames[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
await this.onDatabaseSelected();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* refreshes the view
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
await this.loadData();
|
||||||
|
}
|
||||||
|
|
||||||
|
private async onDatabaseSelected(): Promise<void> {
|
||||||
|
this._tableNames = await this.listTableNames(this.databaseName || '');
|
||||||
|
if (this._tables && this._tableNames && this._tableNames.length > 0) {
|
||||||
|
this._tables.values = this._tableNames.map(t => this.getTableFullName(t));
|
||||||
|
if (this.importTable) {
|
||||||
|
const selectedTable = this._tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema);
|
||||||
|
if (selectedTable) {
|
||||||
|
this._selectedTableName = this.getTableFullName(selectedTable);
|
||||||
|
this._tables.value = this.getTableFullName(selectedTable);
|
||||||
|
} else {
|
||||||
|
this._selectedTableName = this.getTableFullName(this._tableNames[0]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this._selectedTableName = this.getTableFullName(this._tableNames[0]);
|
||||||
|
}
|
||||||
|
this._tables.value = this._selectedTableName;
|
||||||
|
} else if (this._tables) {
|
||||||
|
this._tables.values = [];
|
||||||
|
this._tables.value = '';
|
||||||
|
}
|
||||||
|
await this.onTableSelected();
|
||||||
|
}
|
||||||
|
|
||||||
|
private getTableFullName(table: DatabaseTable): string {
|
||||||
|
return `${table.schema}.${table.tableName}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async onTableSelected(): Promise<void> {
|
||||||
|
this._onSelectedChanged.fire();
|
||||||
|
}
|
||||||
|
|
||||||
|
private get databaseName(): string | undefined {
|
||||||
|
return <string>this._databases?.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
private get databaseTable(): DatabaseTable {
|
||||||
|
let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._selectedTableName);
|
||||||
|
if (!selectedItem) {
|
||||||
|
const value = this._selectedTableName;
|
||||||
|
const parts = value ? value.split('.') : undefined;
|
||||||
|
selectedItem = {
|
||||||
|
databaseName: this.databaseName,
|
||||||
|
tableName: parts && parts.length > 1 ? parts[1] : value,
|
||||||
|
schema: parts && parts.length > 1 ? parts[0] : 'dbo',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
databaseName: this.databaseName,
|
||||||
|
tableName: selectedItem?.tableName,
|
||||||
|
schema: selectedItem?.schema
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -369,7 +369,7 @@ export class DashboardWidget {
|
|||||||
light: this.asAbsolutePath('images/makePredictions.svg'),
|
light: this.asAbsolutePath('images/makePredictions.svg'),
|
||||||
},
|
},
|
||||||
link: '',
|
link: '',
|
||||||
command: constants.mlImportModelCommand
|
command: constants.mlManageModelsCommand
|
||||||
};
|
};
|
||||||
const importModelsButton = this.createTaskButton(view, importMetadata);
|
const importModelsButton = this.createTaskButton(view, importMetadata);
|
||||||
const notebookMetadata: IActionMetadata = {
|
const notebookMetadata: IActionMetadata = {
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ export class WizardView extends MainViewBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public async validate(pageInfo: azdata.window.WizardPageChangeInfo): Promise<boolean> {
|
public async validate(pageInfo: azdata.window.WizardPageChangeInfo): Promise<boolean> {
|
||||||
if (pageInfo.lastPage !== undefined) {
|
if (pageInfo?.lastPage !== undefined) {
|
||||||
let idxLast = pageInfo.lastPage;
|
let idxLast = pageInfo.lastPage;
|
||||||
let lastPage = this._pages[idxLast];
|
let lastPage = this._pages[idxLast];
|
||||||
if (lastPage && lastPage.validate) {
|
if (lastPage && lastPage.validate) {
|
||||||
@@ -86,16 +86,23 @@ export class WizardView extends MainViewBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
|
private async onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
|
||||||
let idxLast = pageInfo.lastPage;
|
if (pageInfo?.lastPage !== undefined) {
|
||||||
let lastPage = this._pages[idxLast];
|
let idxLast = pageInfo.lastPage;
|
||||||
if (lastPage && lastPage.onLeave) {
|
let lastPage = this._pages[idxLast];
|
||||||
await lastPage.onLeave();
|
if (lastPage && lastPage.onLeave) {
|
||||||
|
await lastPage.onLeave();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let idx = pageInfo.newPage;
|
if (pageInfo?.newPage !== undefined) {
|
||||||
let page = this._pages[idx];
|
let idx = pageInfo.newPage;
|
||||||
if (page && page.onEnter) {
|
let page = this._pages[idx];
|
||||||
await page.onEnter();
|
if (page && page.onEnter) {
|
||||||
|
if (this._wizard && this._wizard.pages.length > idx) {
|
||||||
|
this._wizard.pages[idx].title = page.title;
|
||||||
|
}
|
||||||
|
await page.onEnter();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user