mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-17 02:51:36 -05:00
Machine Learning Services Extension - Predict wizard (#9450)
*MLS extension - Added predict wizard
This commit is contained in:
@@ -26,6 +26,7 @@
|
|||||||
"modelManagement": {
|
"modelManagement": {
|
||||||
"registeredModelsDatabaseName": "MlFlowDB",
|
"registeredModelsDatabaseName": "MlFlowDB",
|
||||||
"registeredModelsTableName": "artifacts",
|
"registeredModelsTableName": "artifacts",
|
||||||
|
"registeredModelsTableSchemaName": "dbo",
|
||||||
"amlModelManagementUrl": "modelmanagement.azureml.net",
|
"amlModelManagementUrl": "modelmanagement.azureml.net",
|
||||||
"amlExperienceUrl": "experiments.azureml.net",
|
"amlExperienceUrl": "experiments.azureml.net",
|
||||||
"amlApiVersion": "2018-11-19",
|
"amlApiVersion": "2018-11-19",
|
||||||
|
|||||||
@@ -57,6 +57,10 @@
|
|||||||
"command": "mls.command.managePackages",
|
"command": "mls.command.managePackages",
|
||||||
"title": "%mls.command.managePackages%"
|
"title": "%mls.command.managePackages%"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"command": "mls.command.predictModel",
|
||||||
|
"title": "%mls.command.predictModel%"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"command": "mls.command.manageModels",
|
"command": "mls.command.manageModels",
|
||||||
"title": "%mls.command.manageModels%"
|
"title": "%mls.command.manageModels%"
|
||||||
@@ -110,7 +114,7 @@
|
|||||||
"mls.command.managePackages",
|
"mls.command.managePackages",
|
||||||
"mls.command.manageLanguages",
|
"mls.command.manageLanguages",
|
||||||
"mls.command.manageModels",
|
"mls.command.manageModels",
|
||||||
"mls.command.registerModel"
|
"mls.command.predictModel"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -7,8 +7,9 @@
|
|||||||
"title.endpoints": "Endpoints",
|
"title.endpoints": "Endpoints",
|
||||||
"mls.command.managePackages": "Manage Packages in SQL Server",
|
"mls.command.managePackages": "Manage Packages in SQL Server",
|
||||||
"mls.command.manageLanguages": "Manage External Languages",
|
"mls.command.manageLanguages": "Manage External Languages",
|
||||||
"mls.command.manageModels": "Manage Models",
|
"mls.command.predictModel": "Make prediction",
|
||||||
"mls.command.registerModel": "Register Model",
|
"mls.command.manageModels": "Manage models",
|
||||||
|
"mls.command.registerModel": "Register model",
|
||||||
"mls.command.odbcdriver": "Install ODBC Driver for SQL Server",
|
"mls.command.odbcdriver": "Install ODBC Driver for SQL Server",
|
||||||
"mls.command.mlsdocs": "Machine Learning Services Documentation",
|
"mls.command.mlsdocs": "Machine Learning Services Documentation",
|
||||||
"mls.configuration.title": "Machine Learning Services configurations",
|
"mls.configuration.title": "Machine Learning Services configurations",
|
||||||
|
|||||||
@@ -105,4 +105,28 @@ export class ApiWrapper {
|
|||||||
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
|
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
|
||||||
return vscode.window.showQuickPick(items, options, token);
|
return vscode.window.showQuickPick(items, options, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public listDatabases(connectionId: string): Thenable<string[]> {
|
||||||
|
return azdata.connection.listDatabases(connectionId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public openTextDocument(options?: { language?: string; content?: string; }): Thenable<vscode.TextDocument> {
|
||||||
|
return vscode.workspace.openTextDocument(options);
|
||||||
|
}
|
||||||
|
|
||||||
|
public connect(fileUri: string, connectionId: string): Thenable<void> {
|
||||||
|
return azdata.queryeditor.connect(fileUri, connectionId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public runQuery(fileUri: string, options?: Map<string, string>, runCurrentQuery?: boolean): void {
|
||||||
|
azdata.queryeditor.runQuery(fileUri, options, runCurrentQuery);
|
||||||
|
}
|
||||||
|
|
||||||
|
public showTextDocument(uri: vscode.Uri, options?: vscode.TextDocumentShowOptions): Thenable<vscode.TextEditor> {
|
||||||
|
return vscode.window.showTextDocument(uri, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
public createButton(label: string, position?: azdata.window.DialogButtonPosition): azdata.window.Button {
|
||||||
|
return azdata.window.createButton(label, position);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups';
|
|||||||
// Tasks, commands
|
// Tasks, commands
|
||||||
//
|
//
|
||||||
export const mlManageLanguagesCommand = 'mls.command.manageLanguages';
|
export const mlManageLanguagesCommand = 'mls.command.manageLanguages';
|
||||||
|
export const mlsPredictModelCommand = 'mls.command.predictModel';
|
||||||
export const mlManageModelsCommand = 'mls.command.manageModels';
|
export const mlManageModelsCommand = 'mls.command.manageModels';
|
||||||
export const mlRegisterModelCommand = 'mls.command.registerModel';
|
export const mlRegisterModelCommand = 'mls.command.registerModel';
|
||||||
export const mlManagePackagesCommand = 'mls.command.managePackages';
|
export const mlManagePackagesCommand = 'mls.command.managePackages';
|
||||||
@@ -116,6 +117,12 @@ export const modelCreated = localize('models.created', "Date Created");
|
|||||||
export const modelVersion = localize('models.version', "Version");
|
export const modelVersion = localize('models.version', "Version");
|
||||||
export const browseModels = localize('models.browseButton', "...");
|
export const browseModels = localize('models.browseButton', "...");
|
||||||
export const azureAccount = localize('models.azureAccount', "Azure account");
|
export const azureAccount = localize('models.azureAccount', "Azure account");
|
||||||
|
export const columnDatabase = localize('predict.columnDatabase', "Database");
|
||||||
|
export const columnTable = localize('predict.columnTable', "Table");
|
||||||
|
export const inputColumns = localize('predict.inputColumns', "Input columns");
|
||||||
|
export const outputColumns = localize('predict.outputColumns', "Output column");
|
||||||
|
export const columnName = localize('predict.columnName', "Name");
|
||||||
|
export const inputName = localize('predict.inputName', "Input Name");
|
||||||
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
|
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
|
||||||
export const azureGroup = localize('models.azureGroup', "Azure resource group");
|
export const azureGroup = localize('models.azureGroup', "Azure resource group");
|
||||||
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
|
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
|
||||||
@@ -125,18 +132,25 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model
|
|||||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Ender model source details");
|
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Ender model source details");
|
||||||
|
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Select input columns");
|
||||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
|
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Provide model details");
|
||||||
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
||||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||||
export const azureRegisterModel = localize('models.azureRegisterModel', "Register");
|
export const azureRegisterModel = localize('models.azureRegisterModel', "Register");
|
||||||
|
export const predictModel = localize('models.predictModel', "Predict");
|
||||||
export const registerModelTitle = localize('models.RegisterWizard', "Register model");
|
export const registerModelTitle = localize('models.RegisterWizard', "Register model");
|
||||||
|
export const makePredictionTitle = localize('models.makePredictionTitle', "Make prediction");
|
||||||
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
|
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
|
||||||
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
|
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
|
||||||
export const localModelSource = localize('models.localModelSource', "Upload file");
|
export const localModelSource = localize('models.localModelSource', "Upload file");
|
||||||
export const azureModelSource = localize('models.azureModelSource', "Import from AzureML registry");
|
export const azureModelSource = localize('models.azureModelSource', "Import from AzureML registry");
|
||||||
|
export const registeredModelsSource = localize('models.registeredModelsSource', "Select managed models");
|
||||||
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
|
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
|
||||||
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
|
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
|
||||||
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
||||||
|
export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict");
|
||||||
|
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
||||||
|
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||||
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");
|
export const importModelFailedError = localize('models.importModelFailedError', "Failed to register the model");
|
||||||
|
|
||||||
|
|||||||
@@ -163,4 +163,21 @@ export class QueryRunner {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Executes the query but doesn't fail it is fails
|
||||||
|
* @param connection SQL connection
|
||||||
|
* @param query query to run
|
||||||
|
*/
|
||||||
|
public async safeRunQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||||
|
try {
|
||||||
|
return await this.runQuery(connection, query);
|
||||||
|
} catch (error) {
|
||||||
|
console.log(error);
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import * as fs from 'fs';
|
|||||||
import * as constants from '../common/constants';
|
import * as constants from '../common/constants';
|
||||||
import { promisify } from 'util';
|
import { promisify } from 'util';
|
||||||
import { ApiWrapper } from './apiWrapper';
|
import { ApiWrapper } from './apiWrapper';
|
||||||
|
import { Config } from '../configurations/config';
|
||||||
|
|
||||||
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
|
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
|
||||||
let tempFilePath: string = '';
|
let tempFilePath: string = '';
|
||||||
@@ -25,6 +26,11 @@ export async function execCommandOnTempFile<T>(content: string, command: (filePa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function readFileInHex(filePath: string): Promise<string> {
|
||||||
|
let buffer = await fs.promises.readFile(filePath);
|
||||||
|
return `0X${buffer.toString('hex')}`;
|
||||||
|
}
|
||||||
|
|
||||||
export async function exists(path: string): Promise<boolean> {
|
export async function exists(path: string): Promise<boolean> {
|
||||||
return promisify(fs.exists)(path);
|
return promisify(fs.exists)(path);
|
||||||
}
|
}
|
||||||
@@ -109,8 +115,8 @@ export function isWindows(): boolean {
|
|||||||
* ' => ''
|
* ' => ''
|
||||||
* @param value The string to escape
|
* @param value The string to escape
|
||||||
*/
|
*/
|
||||||
export function doubleEscapeSingleQuotes(value: string): string {
|
export function doubleEscapeSingleQuotes(value: string | undefined): string {
|
||||||
return value.replace(/'/g, '\'\'');
|
return value ? value.replace(/'/g, '\'\'') : '';
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -118,8 +124,8 @@ export function doubleEscapeSingleQuotes(value: string): string {
|
|||||||
* ' => ''
|
* ' => ''
|
||||||
* @param value The string to escape
|
* @param value The string to escape
|
||||||
*/
|
*/
|
||||||
export function doubleEscapeSingleBrackets(value: string): string {
|
export function doubleEscapeSingleBrackets(value: string | undefined): string {
|
||||||
return value.replace(/\[/g, '[[').replace(/\]/g, ']]');
|
return value ? value.replace(/\[/g, '[[').replace(/\]/g, ']]') : '';
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -176,3 +182,48 @@ export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Pr
|
|||||||
|
|
||||||
return choices[result.label] || false;
|
return choices[result.label] || false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function makeLinuxPath(filePath: string): string {
|
||||||
|
const parts = filePath.split('\\');
|
||||||
|
return parts.join('/');
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param currentDb Wraps the given script with database switch scripts
|
||||||
|
* @param databaseName
|
||||||
|
* @param script
|
||||||
|
*/
|
||||||
|
export function getScriptWithDBChange(currentDb: string, databaseName: string, script: string): string {
|
||||||
|
if (!currentDb) {
|
||||||
|
currentDb = 'master';
|
||||||
|
}
|
||||||
|
let escapedDbName = doubleEscapeSingleBrackets(databaseName);
|
||||||
|
let escapedCurrentDbName = doubleEscapeSingleBrackets(currentDb);
|
||||||
|
return `
|
||||||
|
USE [${escapedDbName}]
|
||||||
|
${script}
|
||||||
|
USE [${escapedCurrentDbName}]
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns full name of model registration table
|
||||||
|
* @param config config
|
||||||
|
*/
|
||||||
|
export function getRegisteredModelsThreePartsName(config: Config) {
|
||||||
|
const dbName = doubleEscapeSingleBrackets(config.registeredModelDatabaseName);
|
||||||
|
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
|
||||||
|
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
||||||
|
return `[${dbName}].${schema}.[${tableName}]`;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns full name of model registration table
|
||||||
|
* @param config config object
|
||||||
|
*/
|
||||||
|
export function getRegisteredModelsTowPartsName(config: Config) {
|
||||||
|
const schema = doubleEscapeSingleBrackets(config.registeredModelTableSchemaName);
|
||||||
|
const tableName = doubleEscapeSingleBrackets(config.registeredModelTableName);
|
||||||
|
return `[${schema}].[${tableName}]`;
|
||||||
|
}
|
||||||
|
|||||||
@@ -82,6 +82,13 @@ export class Config {
|
|||||||
return this._configValues.modelManagement.registeredModelsTableName;
|
return this._configValues.modelManagement.registeredModelsTableName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns registered models table schema name
|
||||||
|
*/
|
||||||
|
public get registeredModelTableSchemaName(): string {
|
||||||
|
return this._configValues.modelManagement.registeredModelsTableSchemaName;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns registered models table name
|
* Returns registered models table name
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import { ModelManagementController } from '../views/models/modelManagementContro
|
|||||||
import { RegisteredModelService } from '../modelManagement/registeredModelService';
|
import { RegisteredModelService } from '../modelManagement/registeredModelService';
|
||||||
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
|
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
|
||||||
import { ModelImporter } from '../modelManagement/modelImporter';
|
import { ModelImporter } from '../modelManagement/modelImporter';
|
||||||
|
import { PredictService } from '../prediction/predictService';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The main controller class that initializes the extension
|
* The main controller class that initializes the extension
|
||||||
@@ -109,7 +110,9 @@ export default class MainController implements vscode.Disposable {
|
|||||||
//
|
//
|
||||||
let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
|
let registeredModelService = new RegisteredModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter);
|
||||||
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
|
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
|
||||||
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, azureModelsService, registeredModelService);
|
let predictService = new PredictService(this._apiWrapper, this._queryRunner, this._config);
|
||||||
|
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
|
||||||
|
azureModelsService, registeredModelService, predictService);
|
||||||
|
|
||||||
this._apiWrapper.registerCommand(constants.mlManageLanguagesCommand, (async () => {
|
this._apiWrapper.registerCommand(constants.mlManageLanguagesCommand, (async () => {
|
||||||
await languageController.manageLanguages();
|
await languageController.manageLanguages();
|
||||||
@@ -120,6 +123,9 @@ export default class MainController implements vscode.Disposable {
|
|||||||
this._apiWrapper.registerCommand(constants.mlRegisterModelCommand, (async () => {
|
this._apiWrapper.registerCommand(constants.mlRegisterModelCommand, (async () => {
|
||||||
await modelManagementController.registerModel();
|
await modelManagementController.registerModel();
|
||||||
}));
|
}));
|
||||||
|
this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => {
|
||||||
|
await modelManagementController.predictModel();
|
||||||
|
}));
|
||||||
this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => {
|
this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => {
|
||||||
await packageManager.installDependencies();
|
await packageManager.installDependencies();
|
||||||
}));
|
}));
|
||||||
@@ -135,6 +141,9 @@ export default class MainController implements vscode.Disposable {
|
|||||||
this._apiWrapper.registerTaskHandler(constants.mlRegisterModelCommand, async () => {
|
this._apiWrapper.registerTaskHandler(constants.mlRegisterModelCommand, async () => {
|
||||||
await modelManagementController.registerModel();
|
await modelManagementController.registerModel();
|
||||||
});
|
});
|
||||||
|
this._apiWrapper.registerTaskHandler(constants.mlsPredictModelCommand, async () => {
|
||||||
|
await modelManagementController.predictModel();
|
||||||
|
});
|
||||||
this._apiWrapper.registerTaskHandler(constants.mlOdbcDriverCommand, async () => {
|
this._apiWrapper.registerTaskHandler(constants.mlOdbcDriverCommand, async () => {
|
||||||
await this.serverConfigManager.openOdbcDriverDocuments();
|
await this.serverConfigManager.openOdbcDriverDocuments();
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -48,13 +48,19 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
|||||||
/**
|
/**
|
||||||
* An interface representing registered model
|
* An interface representing registered model
|
||||||
*/
|
*/
|
||||||
export interface RegisteredModel {
|
export interface RegisteredModel extends RegisteredModelDetails {
|
||||||
id?: number,
|
id: number;
|
||||||
artifactName?: string,
|
artifactName: string;
|
||||||
title?: string,
|
}
|
||||||
created?: string,
|
|
||||||
version?: string
|
/**
|
||||||
description?: string
|
* An interface representing registered model
|
||||||
|
*/
|
||||||
|
export interface RegisteredModelDetails {
|
||||||
|
title: string;
|
||||||
|
created?: string;
|
||||||
|
version?: string;
|
||||||
|
description?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
|||||||
import * as utils from '../common/utils';
|
import * as utils from '../common/utils';
|
||||||
import { PackageManager } from '../packageManagement/packageManager';
|
import { PackageManager } from '../packageManagement/packageManager';
|
||||||
import * as constants from '../common/constants';
|
import * as constants from '../common/constants';
|
||||||
|
import * as os from 'os';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Service to import model to database
|
* Service to import model to database
|
||||||
@@ -39,8 +40,8 @@ export class ModelImporter {
|
|||||||
|
|
||||||
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
protected async executeScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||||
|
|
||||||
const parts = modelFolderPath.split('\\');
|
let home = utils.makeLinuxPath(os.homedir());
|
||||||
modelFolderPath = parts.join('/');
|
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||||
|
|
||||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||||
|
|
||||||
@@ -51,9 +52,12 @@ export class ModelImporter {
|
|||||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
|
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
|
||||||
let scripts: string[] = [
|
let scripts: string[] = [
|
||||||
'import mlflow.onnx',
|
'import mlflow.onnx',
|
||||||
|
`tracking_uri = "file://${home}/mlruns"`,
|
||||||
|
'print(tracking_uri)',
|
||||||
'import onnx',
|
'import onnx',
|
||||||
'from mlflow.tracking.client import MlflowClient',
|
'from mlflow.tracking.client import MlflowClient',
|
||||||
`onx = onnx.load("${modelFolderPath}")`,
|
`onx = onnx.load("${modelFolderPath}")`,
|
||||||
|
`mlflow.set_tracking_uri(tracking_uri)`,
|
||||||
'client = MlflowClient()',
|
'client = MlflowClient()',
|
||||||
`exp_name = "${experimentId}"`,
|
`exp_name = "${experimentId}"`,
|
||||||
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
|
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import { ApiWrapper } from '../common/apiWrapper';
|
|||||||
import * as utils from '../common/utils';
|
import * as utils from '../common/utils';
|
||||||
import { Config } from '../configurations/config';
|
import { Config } from '../configurations/config';
|
||||||
import { QueryRunner } from '../common/queryRunner';
|
import { QueryRunner } from '../common/queryRunner';
|
||||||
import { RegisteredModel } from './interfaces';
|
import { RegisteredModel, RegisteredModelDetails } from './interfaces';
|
||||||
import { ModelImporter } from './modelImporter';
|
import { ModelImporter } from './modelImporter';
|
||||||
import * as constants from '../common/constants';
|
import * as constants from '../common/constants';
|
||||||
|
|
||||||
@@ -32,7 +32,10 @@ export class RegisteredModelService {
|
|||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
let list: RegisteredModel[] = [];
|
let list: RegisteredModel[] = [];
|
||||||
if (connection) {
|
if (connection) {
|
||||||
let result = await this.runRegisteredModelsListQuery(connection);
|
let query = this.getConfigureQuery(connection.databaseName);
|
||||||
|
await this._queryRunner.safeRunQuery(connection, query);
|
||||||
|
query = this.registeredModelsQuery();
|
||||||
|
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||||
if (result && result.rows && result.rows.length > 0) {
|
if (result && result.rows && result.rows.length > 0) {
|
||||||
result.rows.forEach(row => {
|
result.rows.forEach(row => {
|
||||||
list.push(this.loadModelData(row));
|
list.push(this.loadModelData(row));
|
||||||
@@ -57,7 +60,8 @@ export class RegisteredModelService {
|
|||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
let updatedModel: RegisteredModel | undefined = undefined;
|
let updatedModel: RegisteredModel | undefined = undefined;
|
||||||
if (connection) {
|
if (connection) {
|
||||||
let result = await this.runUpdateModelQuery(connection, model);
|
const query = this.getUpdateModelScript(connection.databaseName, model);
|
||||||
|
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||||
if (result && result.rows && result.rows.length > 0) {
|
if (result && result.rows && result.rows.length > 0) {
|
||||||
const row = result.rows[0];
|
const row = result.rows[0];
|
||||||
updatedModel = this.loadModelData(row);
|
updatedModel = this.loadModelData(row);
|
||||||
@@ -66,7 +70,7 @@ export class RegisteredModelService {
|
|||||||
return updatedModel;
|
return updatedModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async registerLocalModel(filePath: string, details: RegisteredModel | undefined) {
|
public async registerLocalModel(filePath: string, details: RegisteredModelDetails | undefined) {
|
||||||
let connection = await this.getCurrentConnection();
|
let connection = await this.getCurrentConnection();
|
||||||
if (connection) {
|
if (connection) {
|
||||||
let currentModels = await this.getRegisteredModels();
|
let currentModels = await this.getRegisteredModels();
|
||||||
@@ -93,35 +97,14 @@ export class RegisteredModelService {
|
|||||||
return await this._apiWrapper.getCurrentConnection();
|
return await this._apiWrapper.getCurrentConnection();
|
||||||
}
|
}
|
||||||
|
|
||||||
private async runRegisteredModelsListQuery(connection: azdata.connection.ConnectionProfile): Promise<azdata.SimpleExecuteResult | undefined> {
|
private getConfigureQuery(currentDatabaseName: string): string {
|
||||||
try {
|
return utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, this.configureTable());
|
||||||
return await this._queryRunner.runQuery(connection, this.registeredModelsQuery(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName));
|
|
||||||
} catch {
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private async runUpdateModelQuery(connection: azdata.connection.ConnectionProfile, model: RegisteredModel): Promise<azdata.SimpleExecuteResult | undefined> {
|
private registeredModelsQuery(): string {
|
||||||
try {
|
|
||||||
return await this._queryRunner.runQuery(connection, this.getUpdateModelScript(connection.databaseName, this._config.registeredModelDatabaseName, this._config.registeredModelTableName, model));
|
|
||||||
} catch {
|
|
||||||
return undefined;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private registeredModelsQuery(currentDatabaseName: string, databaseName: string, tableName: string): string {
|
|
||||||
if (!currentDatabaseName) {
|
|
||||||
currentDatabaseName = 'master';
|
|
||||||
}
|
|
||||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
|
||||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
|
||||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
|
||||||
|
|
||||||
return `
|
return `
|
||||||
${this.configureTable(databaseName, tableName)}
|
|
||||||
USE [${escapedCurrentDbName}]
|
|
||||||
SELECT artifact_id, artifact_name, name, description, version, created
|
SELECT artifact_id, artifact_name, name, description, version, created
|
||||||
FROM [${escapedDbName}].dbo.[${escapedTableName}]
|
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||||
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
WHERE artifact_name not like 'MLmodel' and artifact_name not like 'conda.yaml'
|
||||||
Order by artifact_id
|
Order by artifact_id
|
||||||
`;
|
`;
|
||||||
@@ -133,52 +116,74 @@ export class RegisteredModelService {
|
|||||||
* @param databaseName
|
* @param databaseName
|
||||||
* @param tableName
|
* @param tableName
|
||||||
*/
|
*/
|
||||||
private configureTable(databaseName: string, tableName: string): string {
|
private configureTable(): string {
|
||||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
let databaseName = this._config.registeredModelDatabaseName;
|
||||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
let tableName = this._config.registeredModelTableName;
|
||||||
|
let schemaName = this._config.registeredModelTableSchemaName;
|
||||||
|
|
||||||
return `
|
return `
|
||||||
USE [${escapedDbName}]
|
IF NOT EXISTS (
|
||||||
|
SELECT [name]
|
||||||
|
FROM sys.databases
|
||||||
|
WHERE [name] = N'${utils.doubleEscapeSingleQuotes(databaseName)}'
|
||||||
|
)
|
||||||
|
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
||||||
|
GO
|
||||||
|
USE [${utils.doubleEscapeSingleBrackets(databaseName)}]
|
||||||
IF EXISTS
|
IF EXISTS
|
||||||
( SELECT [name]
|
( SELECT [t.name], [s.name]
|
||||||
FROM sys.tables
|
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||||
WHERE [name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
WHERE [t.name] = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||||
|
AND [s.name] = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
||||||
)
|
)
|
||||||
BEGIN
|
BEGIN
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${escapedTableName}') AND NAME='name')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='name')
|
||||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [name] [varchar](256) NULL
|
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [name] [varchar](256) NULL
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='version')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='version')
|
||||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [version] [varchar](256) NULL
|
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [version] [varchar](256) NULL
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='created')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='created')
|
||||||
BEGIN
|
BEGIN
|
||||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [created] [datetime] NULL
|
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [created] [datetime] NULL
|
||||||
ALTER TABLE [dbo].[${escapedTableName}] ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD CONSTRAINT CONSTRAINT_NAME DEFAULT GETDATE() FOR created
|
||||||
END
|
END
|
||||||
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('[dbo].[${escapedTableName}]') AND NAME='description')
|
IF NOT EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${utils.getRegisteredModelsTowPartsName(this._config)}') AND NAME='description')
|
||||||
ALTER TABLE [dbo].[${escapedTableName}] ADD [description] [varchar](256) NULL
|
ALTER TABLE ${utils.getRegisteredModelsTowPartsName(this._config)} ADD [description] [varchar](256) NULL
|
||||||
|
END
|
||||||
|
Else
|
||||||
|
BEGIN
|
||||||
|
CREATE TABLE ${utils.getRegisteredModelsTowPartsName(this._config)}(
|
||||||
|
[artifact_id] [int] IDENTITY(1,1) NOT NULL,
|
||||||
|
[artifact_name] [varchar](256) NOT NULL,
|
||||||
|
[group_path] [varchar](256) NOT NULL,
|
||||||
|
[artifact_content] [varbinary](max) NOT NULL,
|
||||||
|
[artifact_initial_size] [bigint] NULL,
|
||||||
|
[name] [varchar](256) NULL,
|
||||||
|
[version] [varchar](256) NULL,
|
||||||
|
[created] [datetime] NULL,
|
||||||
|
[description] [varchar](256) NULL,
|
||||||
|
CONSTRAINT [artifact_pk] PRIMARY KEY CLUSTERED
|
||||||
|
(
|
||||||
|
[artifact_id] ASC
|
||||||
|
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||||
|
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||||
|
ALTER TABLE [dbo].[artifacts] ADD CONSTRAINT [CONSTRAINT_NAME] DEFAULT (getdate()) FOR [created]
|
||||||
END
|
END
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|
||||||
private getUpdateModelScript(currentDatabaseName: string, databaseName: string, tableName: string, model: RegisteredModel): string {
|
private getUpdateModelScript(currentDatabaseName: string, model: RegisteredModel): string {
|
||||||
|
let updateScript = `
|
||||||
if (!currentDatabaseName) {
|
UPDATE ${utils.getRegisteredModelsTowPartsName(this._config)}
|
||||||
currentDatabaseName = 'master';
|
|
||||||
}
|
|
||||||
let escapedTableName = utils.doubleEscapeSingleBrackets(tableName);
|
|
||||||
let escapedDbName = utils.doubleEscapeSingleBrackets(databaseName);
|
|
||||||
let escapedCurrentDbName = utils.doubleEscapeSingleBrackets(currentDatabaseName);
|
|
||||||
return `
|
|
||||||
USE [${escapedDbName}]
|
|
||||||
UPDATE ${escapedTableName}
|
|
||||||
SET
|
SET
|
||||||
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
name = '${utils.doubleEscapeSingleQuotes(model.title || '')}',
|
||||||
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||||
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
description = '${utils.doubleEscapeSingleQuotes(model.description || '')}'
|
||||||
WHERE artifact_id = ${model.id};
|
WHERE artifact_id = ${model.id}`;
|
||||||
|
|
||||||
USE [${escapedCurrentDbName}]
|
return `
|
||||||
SELECT artifact_id, artifact_name, name, description, version, created from ${escapedDbName}.dbo.[${escapedTableName}]
|
${utils.getScriptWithDBChange(currentDatabaseName, this._config.registeredModelDatabaseName, updateScript)}
|
||||||
|
SELECT artifact_id, artifact_name, name, description, version, created
|
||||||
|
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||||
WHERE artifact_id = ${model.id};
|
WHERE artifact_id = ${model.id};
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
export interface PredictColumn {
|
||||||
|
name: string;
|
||||||
|
dataType?: string;
|
||||||
|
displayName?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DatabaseTable {
|
||||||
|
databaseName: string | undefined;
|
||||||
|
tableName: string | undefined;
|
||||||
|
schema: string | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PredictInputParameters extends DatabaseTable {
|
||||||
|
inputColumns: PredictColumn[] | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PredictParameters extends PredictInputParameters {
|
||||||
|
outputColumns: PredictColumn[] | undefined
|
||||||
|
}
|
||||||
@@ -0,0 +1,203 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
|
||||||
|
import { ApiWrapper } from '../common/apiWrapper';
|
||||||
|
import { QueryRunner } from '../common/queryRunner';
|
||||||
|
import * as utils from '../common/utils';
|
||||||
|
import { RegisteredModel } from '../modelManagement/interfaces';
|
||||||
|
import { PredictParameters, PredictColumn, DatabaseTable } from '../prediction/interfaces';
|
||||||
|
import { Config } from '../configurations/config';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Service to make prediction
|
||||||
|
*/
|
||||||
|
export class PredictService {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates new instance
|
||||||
|
*/
|
||||||
|
constructor(
|
||||||
|
private _apiWrapper: ApiWrapper,
|
||||||
|
private _queryRunner: QueryRunner,
|
||||||
|
private _config: Config) {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the list of databases
|
||||||
|
*/
|
||||||
|
public async getDatabaseList(): Promise<string[]> {
|
||||||
|
let connection = await this.getCurrentConnection();
|
||||||
|
if (connection) {
|
||||||
|
return await this._apiWrapper.listDatabases(connection.connectionId);
|
||||||
|
}
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates prediction script given model info and predict parameters
|
||||||
|
* @param predictParams predict parameters
|
||||||
|
* @param registeredModel model parameters
|
||||||
|
*/
|
||||||
|
public async generatePredictScript(
|
||||||
|
predictParams: PredictParameters,
|
||||||
|
registeredModel: RegisteredModel | undefined,
|
||||||
|
filePath: string | undefined
|
||||||
|
): Promise<string> {
|
||||||
|
let connection = await this.getCurrentConnection();
|
||||||
|
let query = '';
|
||||||
|
if (registeredModel && registeredModel.id) {
|
||||||
|
query = this.getPredictScriptWithModelId(
|
||||||
|
registeredModel.id,
|
||||||
|
predictParams.inputColumns || [],
|
||||||
|
predictParams.outputColumns || [],
|
||||||
|
predictParams);
|
||||||
|
} else if (filePath) {
|
||||||
|
let modelBytes = await utils.readFileInHex(filePath || '');
|
||||||
|
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
|
||||||
|
predictParams.outputColumns || [],
|
||||||
|
predictParams);
|
||||||
|
}
|
||||||
|
let document = await this._apiWrapper.openTextDocument({
|
||||||
|
language: 'sql',
|
||||||
|
content: query
|
||||||
|
});
|
||||||
|
await this._apiWrapper.showTextDocument(document.uri);
|
||||||
|
await this._apiWrapper.connect(document.uri.toString(), connection.connectionId);
|
||||||
|
this._apiWrapper.runQuery(document.uri.toString(), undefined, false);
|
||||||
|
return query;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns list of tables given database name
|
||||||
|
* @param databaseName database name
|
||||||
|
*/
|
||||||
|
public async getTableList(databaseName: string): Promise<DatabaseTable[]> {
|
||||||
|
let connection = await this.getCurrentConnection();
|
||||||
|
let list: DatabaseTable[] = [];
|
||||||
|
if (connection) {
|
||||||
|
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName));
|
||||||
|
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||||
|
if (result && result.rows && result.rows.length > 0) {
|
||||||
|
result.rows.forEach(row => {
|
||||||
|
list.push({
|
||||||
|
databaseName: databaseName,
|
||||||
|
tableName: row[0].displayValue,
|
||||||
|
schema: row[1].displayValue
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*Returns list of column names of a database
|
||||||
|
* @param databaseTable table info
|
||||||
|
*/
|
||||||
|
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<string[]> {
|
||||||
|
let connection = await this.getCurrentConnection();
|
||||||
|
let list: string[] = [];
|
||||||
|
if (connection && databaseTable.databaseName) {
|
||||||
|
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
|
||||||
|
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||||
|
if (result && result.rows && result.rows.length > 0) {
|
||||||
|
result.rows.forEach(row => {
|
||||||
|
list.push(row[0].displayValue);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||||
|
return await this._apiWrapper.getCurrentConnection();
|
||||||
|
}
|
||||||
|
|
||||||
|
private getTableColumnsScript(databaseTable: DatabaseTable): string {
|
||||||
|
return `
|
||||||
|
SELECT COLUMN_NAME,*
|
||||||
|
FROM INFORMATION_SCHEMA.COLUMNS
|
||||||
|
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
|
||||||
|
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
|
||||||
|
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private getTablesScript(databaseName: string): string {
|
||||||
|
return `
|
||||||
|
SELECT TABLE_NAME,TABLE_SCHEMA
|
||||||
|
FROM INFORMATION_SCHEMA.TABLES
|
||||||
|
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private getPredictScriptWithModelId(
|
||||||
|
modelId: number,
|
||||||
|
columns: PredictColumn[],
|
||||||
|
outputColumns: PredictColumn[],
|
||||||
|
databaseNameTable: DatabaseTable): string {
|
||||||
|
return `
|
||||||
|
DECLARE @model VARBINARY(max) = (
|
||||||
|
SELECT artifact_content
|
||||||
|
FROM ${utils.getRegisteredModelsThreePartsName(this._config)}
|
||||||
|
WHERE artifact_id = ${modelId}
|
||||||
|
);
|
||||||
|
WITH predict_input
|
||||||
|
AS (
|
||||||
|
SELECT TOP 1000
|
||||||
|
${this.getColumnNames(columns, 'pi')}
|
||||||
|
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
|
||||||
|
FROM PREDICT(MODEL = @model, DATA = predict_input)
|
||||||
|
WITH (
|
||||||
|
${this.getColumnTypes(outputColumns)}
|
||||||
|
) AS p
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private getPredictScriptWithModelBytes(
|
||||||
|
modelBytes: string,
|
||||||
|
columns: PredictColumn[],
|
||||||
|
outputColumns: PredictColumn[],
|
||||||
|
databaseNameTable: DatabaseTable): string {
|
||||||
|
return `
|
||||||
|
WITH predict_input
|
||||||
|
AS (
|
||||||
|
SELECT TOP 1000
|
||||||
|
${this.getColumnNames(columns, 'pi')}
|
||||||
|
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
${this.getInputColumnNames(columns, 'predict_input')}, ${this.getColumnNames(outputColumns, 'p')}
|
||||||
|
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
|
||||||
|
WITH (
|
||||||
|
${this.getColumnTypes(outputColumns)}
|
||||||
|
) AS p
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private getColumnNames(columns: PredictColumn[], tableName: string) {
|
||||||
|
return columns.map(c => {
|
||||||
|
return c.displayName ? `${tableName}.${c.name} AS ${c.displayName}` : `${tableName}.${c.name}`;
|
||||||
|
}).join(',\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||||
|
return columns.map(c => {
|
||||||
|
return c.displayName ? `${tableName}.${c.displayName}` : `${tableName}.${c.name}`;
|
||||||
|
}).join(',\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
private getColumnTypes(columns: PredictColumn[]) {
|
||||||
|
return columns.map(c => {
|
||||||
|
return `${c.name} ${c.dataType}`;
|
||||||
|
}).join(',\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -13,7 +13,7 @@ import { azureResource } from '../../../typings/azure-resource';
|
|||||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
import { ViewBase } from '../../../views/viewBase';
|
import { ViewBase } from '../../../views/viewBase';
|
||||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||||
import { RegisterModelWizard } from '../../../views/models/registerModelWizard';
|
import { RegisterModelWizard } from '../../../views/models/registerModels/registerModelWizard';
|
||||||
|
|
||||||
describe('Register Model Wizard', () => {
|
describe('Register Model Wizard', () => {
|
||||||
it('Should create view components successfully ', async function (): Promise<void> {
|
it('Should create view components successfully ', async function (): Promise<void> {
|
||||||
@@ -74,7 +74,8 @@ describe('Register Model Wizard', () => {
|
|||||||
let localModels: RegisteredModel[] = [
|
let localModels: RegisteredModel[] = [
|
||||||
{
|
{
|
||||||
id: 1,
|
id: 1,
|
||||||
artifactName: 'model'
|
artifactName: 'model',
|
||||||
|
title: 'model'
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
view.on(ListModelsEventName, () => {
|
view.on(ListModelsEventName, () => {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
import * as should from 'should';
|
import * as should from 'should';
|
||||||
import 'mocha';
|
import 'mocha';
|
||||||
import { createContext } from './utils';
|
import { createContext } from './utils';
|
||||||
import { RegisteredModelsDialog } from '../../../views/models/registeredModelsDialog';
|
import { RegisteredModelsDialog } from '../../../views/models/registerModels/registeredModelsDialog';
|
||||||
import { ListModelsEventName } from '../../../views/models/modelViewBase';
|
import { ListModelsEventName } from '../../../views/models/modelViewBase';
|
||||||
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
import { ViewBase } from '../../../views/viewBase';
|
import { ViewBase } from '../../../views/viewBase';
|
||||||
@@ -30,7 +30,8 @@ describe('Registered Models Dialog', () => {
|
|||||||
let models: RegisteredModel[] = [
|
let models: RegisteredModel[] = [
|
||||||
{
|
{
|
||||||
id: 1,
|
id: 1,
|
||||||
artifactName: 'model'
|
artifactName: 'model',
|
||||||
|
title: ''
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
view.on(ListModelsEventName, () => {
|
view.on(ListModelsEventName, () => {
|
||||||
|
|||||||
@@ -246,6 +246,7 @@ export function createViewContext(): ViewTestContext {
|
|||||||
modelView: undefined!,
|
modelView: undefined!,
|
||||||
valid: true
|
valid: true
|
||||||
};
|
};
|
||||||
|
apiWrapper.setup(x => x.createButton(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => dialogButton);
|
||||||
apiWrapper.setup(x => x.createTab(TypeMoq.It.isAny())).returns(() => tab);
|
apiWrapper.setup(x => x.createTab(TypeMoq.It.isAny())).returns(() => tab);
|
||||||
apiWrapper.setup(x => x.createWizard(TypeMoq.It.isAny())).returns(() => wizard);
|
apiWrapper.setup(x => x.createWizard(TypeMoq.It.isAny())).returns(() => wizard);
|
||||||
apiWrapper.setup(x => x.createWizardPage(TypeMoq.It.isAny())).returns(() => wizardPage);
|
apiWrapper.setup(x => x.createWizardPage(TypeMoq.It.isAny())).returns(() => wizardPage);
|
||||||
|
|||||||
@@ -3,7 +3,9 @@
|
|||||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import { ViewBase, LocalFileEventName, LocalFolderEventName } from './viewBase';
|
import * as vscode from 'vscode';
|
||||||
|
|
||||||
|
import { ViewBase, LocalPathsEventName } from './viewBase';
|
||||||
import { ApiWrapper } from '../common/apiWrapper';
|
import { ApiWrapper } from '../common/apiWrapper';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -36,11 +38,8 @@ export abstract class ControllerBase {
|
|||||||
* @param view view
|
* @param view view
|
||||||
*/
|
*/
|
||||||
public registerEvents(view: ViewBase): void {
|
public registerEvents(view: ViewBase): void {
|
||||||
view.on(LocalFileEventName, async () => {
|
view.on(LocalPathsEventName, async (args) => {
|
||||||
await this.executeAction(view, LocalFileEventName, this.getLocalFilePath, this._apiWrapper);
|
await this.executeAction(view, LocalPathsEventName, this.getLocalPaths, this._apiWrapper, args);
|
||||||
});
|
|
||||||
view.on(LocalFolderEventName, async () => {
|
|
||||||
await this.executeAction(view, LocalFolderEventName, this.getLocalFolderPath, this._apiWrapper);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,25 +47,8 @@ export abstract class ControllerBase {
|
|||||||
* Returns local file path picked by the user
|
* Returns local file path picked by the user
|
||||||
* @param apiWrapper apiWrapper
|
* @param apiWrapper apiWrapper
|
||||||
*/
|
*/
|
||||||
public async getLocalFilePath(apiWrapper: ApiWrapper): Promise<string> {
|
public async getLocalPaths(apiWrapper: ApiWrapper, options: vscode.OpenDialogOptions): Promise<string[]> {
|
||||||
let result = await apiWrapper.showOpenDialog({
|
let result = await apiWrapper.showOpenDialog(options);
|
||||||
canSelectFiles: true,
|
return result ? result?.map(x => x.fsPath) : [];
|
||||||
canSelectFolders: false,
|
|
||||||
canSelectMany: false
|
|
||||||
});
|
|
||||||
return result && result.length > 0 ? result[0].fsPath : '';
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns local folder path picked by the user
|
|
||||||
* @param apiWrapper apiWrapper
|
|
||||||
*/
|
|
||||||
public async getLocalFolderPath(apiWrapper: ApiWrapper): Promise<string> {
|
|
||||||
let result = await apiWrapper.showOpenDialog({
|
|
||||||
canSelectFiles: false,
|
|
||||||
canSelectFolders: true,
|
|
||||||
canSelectMany: false
|
|
||||||
});
|
|
||||||
return result && result.length > 0 ? result[0].fsPath : '';
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ export interface IPageView {
|
|||||||
component: azdata.Component | undefined;
|
component: azdata.Component | undefined;
|
||||||
onEnter?: () => Promise<void>;
|
onEnter?: () => Promise<void>;
|
||||||
onLeave?: () => Promise<void>;
|
onLeave?: () => Promise<void>;
|
||||||
|
validate?: () => Promise<boolean>;
|
||||||
refresh: () => Promise<void>;
|
refresh: () => Promise<void>;
|
||||||
viewPanel: azdata.window.ModelViewPanel | undefined;
|
viewPanel: azdata.window.ModelViewPanel | undefined;
|
||||||
title: string;
|
title: string;
|
||||||
@@ -32,3 +33,4 @@ export interface AzureModelResource extends AzureWorkspaceResource {
|
|||||||
model?: WorkspaceModel;
|
model?: WorkspaceModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
|
import * as vscode from 'vscode';
|
||||||
|
|
||||||
import { ModelViewBase } from './modelViewBase';
|
import { ModelViewBase } from './modelViewBase';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
@@ -43,9 +45,17 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen
|
|||||||
}
|
}
|
||||||
}).component();
|
}).component();
|
||||||
this._localBrowse.onDidClick(async () => {
|
this._localBrowse.onDidClick(async () => {
|
||||||
const filePath = await this.getLocalFilePath();
|
|
||||||
|
let options: vscode.OpenDialogOptions = {
|
||||||
|
canSelectFiles: true,
|
||||||
|
canSelectFolders: false,
|
||||||
|
canSelectMany: false,
|
||||||
|
filters: { 'ONNX File': ['onnx'] }
|
||||||
|
};
|
||||||
|
|
||||||
|
const filePaths = await this.getLocalPaths(options);
|
||||||
if (this._localPath) {
|
if (this._localPath) {
|
||||||
this._localPath.value = filePath;
|
this._localPath.value = filePaths && filePaths.length > 0 ? filePaths[0] : '';
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import { ModelViewBase } from './modelViewBase';
|
|||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
import { IDataComponent } from '../interfaces';
|
import { IDataComponent } from '../interfaces';
|
||||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to pick local models file
|
* View to pick local models file
|
||||||
*/
|
*/
|
||||||
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModel> {
|
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<RegisteredModelDetails> {
|
||||||
|
|
||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
private _nameComponent: azdata.InputBoxComponent | undefined;
|
private _nameComponent: azdata.InputBoxComponent | undefined;
|
||||||
@@ -81,9 +81,9 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone
|
|||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): RegisteredModel {
|
public get data(): RegisteredModelDetails {
|
||||||
return {
|
return {
|
||||||
title: this._nameComponent?.value,
|
title: this._nameComponent?.value || '',
|
||||||
description: this._descriptionComponent?.value
|
description: this._descriptionComponent?.value
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ import { ApiWrapper } from '../../common/apiWrapper';
|
|||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
import { IPageView, IDataComponent } from '../interfaces';
|
import { IPageView, IDataComponent } from '../interfaces';
|
||||||
import { ModelDetailsComponent } from './modelDetailsComponent';
|
import { ModelDetailsComponent } from './modelDetailsComponent';
|
||||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
import { RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to pick model details
|
* View to pick model details
|
||||||
*/
|
*/
|
||||||
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModel> {
|
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<RegisteredModelDetails> {
|
||||||
|
|
||||||
private _form: azdata.FormContainer | undefined;
|
private _form: azdata.FormContainer | undefined;
|
||||||
private _formBuilder: azdata.FormBuilder | undefined;
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
@@ -43,7 +43,7 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
|||||||
/**
|
/**
|
||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): RegisteredModel | undefined {
|
public get data(): RegisteredModelDetails | undefined {
|
||||||
return this.modelDetails?.data;
|
return this.modelDetails?.data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,4 +66,13 @@ export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataC
|
|||||||
public get title(): string {
|
public get title(): string {
|
||||||
return constants.modelDetailsPageTitle;
|
return constants.modelDetailsPageTitle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public validate(): Promise<boolean> {
|
||||||
|
if (this.data && this.data.title) {
|
||||||
|
return Promise.resolve(true);
|
||||||
|
} else {
|
||||||
|
this.showErrorMessage(constants.modelNameRequiredError);
|
||||||
|
return Promise.resolve(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,14 +9,23 @@ import { azureResource } from '../../typings/azure-resource';
|
|||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
|
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||||
|
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
|
||||||
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
|
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
|
||||||
import { RegisteredModelsDialog } from './registeredModelsDialog';
|
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
|
||||||
import { AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName, ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName, RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName } from './modelViewBase';
|
import {
|
||||||
|
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||||
|
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
|
||||||
|
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||||
|
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
|
||||||
|
} from './modelViewBase';
|
||||||
import { ControllerBase } from '../controllerBase';
|
import { ControllerBase } from '../controllerBase';
|
||||||
import { RegisterModelWizard } from './registerModelWizard';
|
import { RegisterModelWizard } from './registerModels/registerModelWizard';
|
||||||
import * as fs from 'fs';
|
import * as fs from 'fs';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../common/constants';
|
||||||
|
import { PredictWizard } from './prediction/predictWizard';
|
||||||
|
import { AzureModelResource } from '../interfaces';
|
||||||
|
import { PredictService } from '../../prediction/predictService';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Model management UI controller
|
* Model management UI controller
|
||||||
@@ -30,7 +39,8 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
apiWrapper: ApiWrapper,
|
apiWrapper: ApiWrapper,
|
||||||
private _root: string,
|
private _root: string,
|
||||||
private _amlService: AzureModelRegistryService,
|
private _amlService: AzureModelRegistryService,
|
||||||
private _registeredModelService: RegisteredModelService) {
|
private _registeredModelService: RegisteredModelService,
|
||||||
|
private _predictService: PredictService) {
|
||||||
super(apiWrapper);
|
super(apiWrapper);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,6 +66,23 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
return view;
|
return view;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Opens the wizard for prediction
|
||||||
|
*/
|
||||||
|
public async predictModel(): Promise<ModelViewBase> {
|
||||||
|
|
||||||
|
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||||
|
|
||||||
|
this.registerEvents(view);
|
||||||
|
|
||||||
|
// Open view
|
||||||
|
//
|
||||||
|
view.open();
|
||||||
|
await view.refresh();
|
||||||
|
return view;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Register events in the main view
|
* Register events in the main view
|
||||||
* @param view main view
|
* @param view main view
|
||||||
@@ -102,6 +129,28 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
||||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
|
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model, registerArgs.details);
|
||||||
});
|
});
|
||||||
|
view.on(DownloadAzureModelEventName, async (arg) => {
|
||||||
|
let registerArgs = <AzureModelResource>arg;
|
||||||
|
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
|
||||||
|
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
|
||||||
|
});
|
||||||
|
view.on(ListDatabaseNamesEventName, async () => {
|
||||||
|
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
|
||||||
|
});
|
||||||
|
view.on(ListTableNamesEventName, async (arg) => {
|
||||||
|
let dbName = <string>arg;
|
||||||
|
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
|
||||||
|
});
|
||||||
|
view.on(ListColumnNamesEventName, async (arg) => {
|
||||||
|
let tableColumnsArgs = <DatabaseTable>arg;
|
||||||
|
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
|
||||||
|
tableColumnsArgs);
|
||||||
|
});
|
||||||
|
view.on(PredictModelEventName, async (arg) => {
|
||||||
|
let predictArgs = <PredictModelEventArgs>arg;
|
||||||
|
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
|
||||||
|
predictArgs, predictArgs.model, predictArgs.filePath);
|
||||||
|
});
|
||||||
view.on(SourceModelSelectedEventName, () => {
|
view.on(SourceModelSelectedEventName, () => {
|
||||||
view.refresh();
|
view.refresh();
|
||||||
});
|
});
|
||||||
@@ -158,7 +207,7 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
||||||
}
|
}
|
||||||
|
|
||||||
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModel | undefined): Promise<void> {
|
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||||
if (filePath) {
|
if (filePath) {
|
||||||
await service.registerLocalModel(filePath, details);
|
await service.registerLocalModel(filePath, details);
|
||||||
} else {
|
} else {
|
||||||
@@ -175,7 +224,7 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
resourceGroup: azureResource.AzureResource | undefined,
|
resourceGroup: azureResource.AzureResource | undefined,
|
||||||
workspace: Workspace | undefined,
|
workspace: Workspace | undefined,
|
||||||
model: WorkspaceModel | undefined,
|
model: WorkspaceModel | undefined,
|
||||||
details: RegisteredModel | undefined): Promise<void> {
|
details: RegisteredModelDetails | undefined): Promise<void> {
|
||||||
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
|
if (!account || !subscription || !resourceGroup || !workspace || !model || !details) {
|
||||||
throw Error(constants.invalidAzureResourceError);
|
throw Error(constants.invalidAzureResourceError);
|
||||||
}
|
}
|
||||||
@@ -188,4 +237,47 @@ export class ModelManagementController extends ControllerBase {
|
|||||||
throw Error(constants.invalidModelToRegisterError);
|
throw Error(constants.invalidModelToRegisterError);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
||||||
|
return await predictService.getDatabaseList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
|
||||||
|
return await predictService.getTableList(databaseName);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
|
||||||
|
return await predictService.getTableColumnsList(databaseTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async generatePredictScript(
|
||||||
|
predictService: PredictService,
|
||||||
|
predictParams: PredictParameters,
|
||||||
|
registeredModel: RegisteredModel | undefined,
|
||||||
|
filePath: string | undefined
|
||||||
|
): Promise<string> {
|
||||||
|
if (!predictParams) {
|
||||||
|
throw Error(constants.invalidModelToPredictError);
|
||||||
|
}
|
||||||
|
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async downloadAzureModel(
|
||||||
|
azureService: AzureModelRegistryService,
|
||||||
|
account: azdata.Account | undefined,
|
||||||
|
subscription: azureResource.AzureResourceSubscription | undefined,
|
||||||
|
resourceGroup: azureResource.AzureResource | undefined,
|
||||||
|
workspace: Workspace | undefined,
|
||||||
|
model: WorkspaceModel | undefined): Promise<string> {
|
||||||
|
if (!account || !subscription || !resourceGroup || !workspace || !model) {
|
||||||
|
throw Error(constants.invalidAzureResourceError);
|
||||||
|
}
|
||||||
|
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
|
||||||
|
if (filePath) {
|
||||||
|
return filePath;
|
||||||
|
} else {
|
||||||
|
throw Error(constants.invalidModelToRegisterError);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import { IPageView, IDataComponent } from '../interfaces';
|
|||||||
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
|
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
|
||||||
import { LocalModelsComponent } from './localModelsComponent';
|
import { LocalModelsComponent } from './localModelsComponent';
|
||||||
import { AzureModelsComponent } from './azureModelsComponent';
|
import { AzureModelsComponent } from './azureModelsComponent';
|
||||||
|
import { CurrentModelsTable } from './registerModels/currentModelsTable';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to pick model source
|
* View to pick model source
|
||||||
@@ -22,8 +23,9 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
public modelResources: ModelSourcesComponent | undefined;
|
public modelResources: ModelSourcesComponent | undefined;
|
||||||
public localModelsComponent: LocalModelsComponent | undefined;
|
public localModelsComponent: LocalModelsComponent | undefined;
|
||||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||||
|
public registeredModelsComponent: CurrentModelsTable | undefined;
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,13 +36,15 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
|
||||||
this._formBuilder = modelBuilder.formContainer();
|
this._formBuilder = modelBuilder.formContainer();
|
||||||
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this);
|
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
|
||||||
this.modelResources.registerComponent(modelBuilder);
|
this.modelResources.registerComponent(modelBuilder);
|
||||||
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
|
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this);
|
||||||
this.localModelsComponent.registerComponent(modelBuilder);
|
this.localModelsComponent.registerComponent(modelBuilder);
|
||||||
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
|
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this);
|
||||||
this.azureModelsComponent.registerComponent(modelBuilder);
|
this.azureModelsComponent.registerComponent(modelBuilder);
|
||||||
this.modelResources.addComponents(this._formBuilder);
|
this.modelResources.addComponents(this._formBuilder);
|
||||||
|
this.registeredModelsComponent = new CurrentModelsTable(this._apiWrapper, this);
|
||||||
|
this.registeredModelsComponent.registerComponent(modelBuilder);
|
||||||
this.refresh();
|
this.refresh();
|
||||||
this._form = this._formBuilder.component();
|
this._form = this._formBuilder.component();
|
||||||
return this._form;
|
return this._form;
|
||||||
@@ -66,19 +70,29 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
public async refresh(): Promise<void> {
|
public async refresh(): Promise<void> {
|
||||||
if (this._formBuilder) {
|
if (this._formBuilder) {
|
||||||
if (this.modelResources && this.modelResources.data === ModelSourceType.Local) {
|
if (this.modelResources && this.modelResources.data === ModelSourceType.Local) {
|
||||||
if (this.localModelsComponent && this.azureModelsComponent) {
|
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||||
this.localModelsComponent.addComponents(this._formBuilder);
|
this.localModelsComponent.addComponents(this._formBuilder);
|
||||||
await this.localModelsComponent.refresh();
|
await this.localModelsComponent.refresh();
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) {
|
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure) {
|
||||||
if (this.localModelsComponent && this.azureModelsComponent) {
|
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||||
this.azureModelsComponent.addComponents(this._formBuilder);
|
this.azureModelsComponent.addComponents(this._formBuilder);
|
||||||
|
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||||
await this.azureModelsComponent.refresh();
|
await this.azureModelsComponent.refresh();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels) {
|
||||||
|
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||||
|
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||||
|
this.registeredModelsComponent.addComponents(this._formBuilder);
|
||||||
|
await this.registeredModelsComponent.refresh();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,4 +103,21 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
|
|||||||
public get title(): string {
|
public get title(): string {
|
||||||
return constants.modelSourcePageTitle;
|
return constants.modelSourcePageTitle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public validate(): Promise<boolean> {
|
||||||
|
let validated = false;
|
||||||
|
if (this.modelResources && this.modelResources.data === ModelSourceType.Local && this.localModelsComponent) {
|
||||||
|
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
|
||||||
|
|
||||||
|
} else if (this.modelResources && this.modelResources.data === ModelSourceType.Azure && this.azureModelsComponent) {
|
||||||
|
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.model !== undefined;
|
||||||
|
|
||||||
|
} else if (this.modelResources && this.modelResources.data === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
|
||||||
|
validated = this.registeredModelsComponent.data !== undefined;
|
||||||
|
}
|
||||||
|
if (!validated) {
|
||||||
|
this.showErrorMessage(constants.invalidModelToSelectError);
|
||||||
|
}
|
||||||
|
return Promise.resolve(validated);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ import { IDataComponent } from '../interfaces';
|
|||||||
|
|
||||||
export enum ModelSourceType {
|
export enum ModelSourceType {
|
||||||
Local,
|
Local,
|
||||||
Azure
|
Azure,
|
||||||
|
RegisteredModels
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* View to pick model source
|
* View to pick model source
|
||||||
@@ -22,9 +23,10 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
private _flexContainer: azdata.FlexContainer | undefined;
|
private _flexContainer: azdata.FlexContainer | undefined;
|
||||||
private _amlModel: azdata.RadioButtonComponent | undefined;
|
private _amlModel: azdata.RadioButtonComponent | undefined;
|
||||||
private _localModel: azdata.RadioButtonComponent | undefined;
|
private _localModel: azdata.RadioButtonComponent | undefined;
|
||||||
private _isLocalModel: boolean = true;
|
private _registeredModels: azdata.RadioButtonComponent | undefined;
|
||||||
|
private _sourceType: ModelSourceType = ModelSourceType.Local;
|
||||||
|
|
||||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +40,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
value: 'local',
|
value: 'local',
|
||||||
name: 'modelLocation',
|
name: 'modelLocation',
|
||||||
label: constants.localModelSource,
|
label: constants.localModelSource,
|
||||||
checked: true
|
checked: this._options[0] === ModelSourceType.Local
|
||||||
}).component();
|
}).component();
|
||||||
|
|
||||||
|
|
||||||
@@ -47,26 +49,58 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
value: 'aml',
|
value: 'aml',
|
||||||
name: 'modelLocation',
|
name: 'modelLocation',
|
||||||
label: constants.azureModelSource,
|
label: constants.azureModelSource,
|
||||||
|
checked: this._options[0] === ModelSourceType.Azure
|
||||||
|
}).component();
|
||||||
|
|
||||||
|
this._registeredModels = modelBuilder.radioButton()
|
||||||
|
.withProperties({
|
||||||
|
value: 'registered',
|
||||||
|
name: 'modelLocation',
|
||||||
|
label: constants.registeredModelsSource,
|
||||||
|
checked: this._options[0] === ModelSourceType.RegisteredModels
|
||||||
}).component();
|
}).component();
|
||||||
|
|
||||||
this._localModel.onDidClick(() => {
|
this._localModel.onDidClick(() => {
|
||||||
this._isLocalModel = true;
|
this._sourceType = ModelSourceType.Local;
|
||||||
this.sendRequest(SourceModelSelectedEventName);
|
this.sendRequest(SourceModelSelectedEventName);
|
||||||
|
|
||||||
});
|
});
|
||||||
this._amlModel.onDidClick(() => {
|
this._amlModel.onDidClick(() => {
|
||||||
this._isLocalModel = false;
|
this._sourceType = ModelSourceType.Azure;
|
||||||
this.sendRequest(SourceModelSelectedEventName);
|
this.sendRequest(SourceModelSelectedEventName);
|
||||||
});
|
});
|
||||||
|
this._registeredModels.onDidClick(() => {
|
||||||
|
this._sourceType = ModelSourceType.RegisteredModels;
|
||||||
|
this.sendRequest(SourceModelSelectedEventName);
|
||||||
|
});
|
||||||
|
let components: azdata.RadioButtonComponent[] = [];
|
||||||
|
|
||||||
|
this._options.forEach(option => {
|
||||||
|
switch (option) {
|
||||||
|
case ModelSourceType.Local:
|
||||||
|
if (this._localModel) {
|
||||||
|
components.push(this._localModel);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case ModelSourceType.Azure:
|
||||||
|
if (this._amlModel) {
|
||||||
|
components.push(this._amlModel);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case ModelSourceType.RegisteredModels:
|
||||||
|
if (this._registeredModels) {
|
||||||
|
components.push(this._registeredModels);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
this._sourceType = this._options[0];
|
||||||
|
|
||||||
this._flexContainer = modelBuilder.flexContainer()
|
this._flexContainer = modelBuilder.flexContainer()
|
||||||
.withLayout({
|
.withLayout({
|
||||||
flexFlow: 'column',
|
flexFlow: 'column',
|
||||||
justifyContent: 'space-between'
|
justifyContent: 'space-between'
|
||||||
}).withItems([
|
}).withItems(components).component();
|
||||||
this._localModel, this._amlModel]
|
|
||||||
).component();
|
|
||||||
|
|
||||||
this._form = modelBuilder.formContainer().withFormItems([{
|
this._form = modelBuilder.formContainer().withFormItems([{
|
||||||
title: '',
|
title: '',
|
||||||
@@ -92,7 +126,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone
|
|||||||
* Returns selected data
|
* Returns selected data
|
||||||
*/
|
*/
|
||||||
public get data(): ModelSourceType {
|
public get data(): ModelSourceType {
|
||||||
return this._isLocalModel ? ModelSourceType.Local : ModelSourceType.Azure;
|
return this._sourceType;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ import * as azdata from 'azdata';
|
|||||||
import { azureResource } from '../../typings/azure-resource';
|
import { azureResource } from '../../typings/azure-resource';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../common/apiWrapper';
|
||||||
import { ViewBase } from '../viewBase';
|
import { ViewBase } from '../viewBase';
|
||||||
import { RegisteredModel, WorkspaceModel } from '../../modelManagement/interfaces';
|
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
|
||||||
|
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
|
||||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||||
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
||||||
|
|
||||||
@@ -16,13 +17,18 @@ export interface AzureResourceEventArgs extends AzureWorkspaceResource {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
|
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
|
||||||
details?: RegisteredModel
|
details?: RegisteredModelDetails
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs {
|
export interface RegisterAzureModelEventArgs extends AzureModelResource, RegisterModelEventArgs {
|
||||||
model?: WorkspaceModel;
|
model?: WorkspaceModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface PredictModelEventArgs extends PredictParameters {
|
||||||
|
model?: RegisteredModel;
|
||||||
|
filePath?: string;
|
||||||
|
}
|
||||||
|
|
||||||
export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
|
export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
|
||||||
filePath?: string;
|
filePath?: string;
|
||||||
}
|
}
|
||||||
@@ -32,11 +38,16 @@ export interface RegisterLocalModelEventArgs extends RegisterModelEventArgs {
|
|||||||
export const ListModelsEventName = 'listModels';
|
export const ListModelsEventName = 'listModels';
|
||||||
export const ListAzureModelsEventName = 'listAzureModels';
|
export const ListAzureModelsEventName = 'listAzureModels';
|
||||||
export const ListAccountsEventName = 'listAccounts';
|
export const ListAccountsEventName = 'listAccounts';
|
||||||
|
export const ListDatabaseNamesEventName = 'listDatabaseNames';
|
||||||
|
export const ListTableNamesEventName = 'listTableNames';
|
||||||
|
export const ListColumnNamesEventName = 'listColumnNames';
|
||||||
export const ListSubscriptionsEventName = 'listSubscriptions';
|
export const ListSubscriptionsEventName = 'listSubscriptions';
|
||||||
export const ListGroupsEventName = 'listGroups';
|
export const ListGroupsEventName = 'listGroups';
|
||||||
export const ListWorkspacesEventName = 'listWorkspaces';
|
export const ListWorkspacesEventName = 'listWorkspaces';
|
||||||
export const RegisterLocalModelEventName = 'registerLocalModel';
|
export const RegisterLocalModelEventName = 'registerLocalModel';
|
||||||
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
|
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
|
||||||
|
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
|
||||||
|
export const PredictModelEventName = 'predictModel';
|
||||||
export const RegisterModelEventName = 'registerModel';
|
export const RegisterModelEventName = 'registerModel';
|
||||||
export const SourceModelSelectedEventName = 'sourceModelSelected';
|
export const SourceModelSelectedEventName = 'sourceModelSelected';
|
||||||
|
|
||||||
@@ -59,7 +70,12 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
RegisterLocalModelEventName,
|
RegisterLocalModelEventName,
|
||||||
RegisterAzureModelEventName,
|
RegisterAzureModelEventName,
|
||||||
RegisterModelEventName,
|
RegisterModelEventName,
|
||||||
SourceModelSelectedEventName]);
|
SourceModelSelectedEventName,
|
||||||
|
ListDatabaseNamesEventName,
|
||||||
|
ListTableNamesEventName,
|
||||||
|
ListColumnNamesEventName,
|
||||||
|
PredictModelEventName,
|
||||||
|
DownloadAzureModelEventName]);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -91,6 +107,27 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
return await this.sendDataRequest(ListAccountsEventName);
|
return await this.sendDataRequest(ListAccountsEventName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* lists database names
|
||||||
|
*/
|
||||||
|
public async listDatabaseNames(): Promise<string[]> {
|
||||||
|
return await this.sendDataRequest(ListDatabaseNamesEventName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* lists table names
|
||||||
|
*/
|
||||||
|
public async listTableNames(dbName: string): Promise<DatabaseTable[]> {
|
||||||
|
return await this.sendDataRequest(ListTableNamesEventName, dbName);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* lists column names
|
||||||
|
*/
|
||||||
|
public async listColumnNames(table: DatabaseTable): Promise<string[]> {
|
||||||
|
return await this.sendDataRequest(ListColumnNamesEventName, table);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* lists azure subscriptions
|
* lists azure subscriptions
|
||||||
* @param account azure account
|
* @param account azure account
|
||||||
@@ -106,7 +143,7 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
* registers local model
|
* registers local model
|
||||||
* @param localFilePath local file path
|
* @param localFilePath local file path
|
||||||
*/
|
*/
|
||||||
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModel | undefined): Promise<void> {
|
public async registerLocalModel(localFilePath: string | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||||
const args: RegisterLocalModelEventArgs = {
|
const args: RegisterLocalModelEventArgs = {
|
||||||
filePath: localFilePath,
|
filePath: localFilePath,
|
||||||
details: details
|
details: details
|
||||||
@@ -114,17 +151,38 @@ export abstract class ModelViewBase extends ViewBase {
|
|||||||
return await this.sendDataRequest(RegisterLocalModelEventName, args);
|
return await this.sendDataRequest(RegisterLocalModelEventName, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* download azure model
|
||||||
|
* @param args azure resource
|
||||||
|
*/
|
||||||
|
public async downloadAzureModel(resource: AzureModelResource | undefined): Promise<string> {
|
||||||
|
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* registers azure model
|
* registers azure model
|
||||||
* @param args azure resource
|
* @param args azure resource
|
||||||
*/
|
*/
|
||||||
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModel | undefined): Promise<void> {
|
public async registerAzureModel(resource: AzureModelResource | undefined, details: RegisteredModelDetails | undefined): Promise<void> {
|
||||||
const args: RegisterAzureModelEventArgs = Object.assign({}, resource, {
|
const args: RegisterAzureModelEventArgs = Object.assign({}, resource, {
|
||||||
details: details
|
details: details
|
||||||
});
|
});
|
||||||
return await this.sendDataRequest(RegisterAzureModelEventName, args);
|
return await this.sendDataRequest(RegisterAzureModelEventName, args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* registers azure model
|
||||||
|
* @param args azure resource
|
||||||
|
*/
|
||||||
|
public async generatePredictScript(model: RegisteredModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> {
|
||||||
|
const args: PredictModelEventArgs = Object.assign({}, params, {
|
||||||
|
model: model,
|
||||||
|
filePath: filePath,
|
||||||
|
loadFromRegisteredModel: !filePath
|
||||||
|
});
|
||||||
|
return await this.sendDataRequest(PredictModelEventName, args);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* list resource groups
|
* list resource groups
|
||||||
* @param account azure account
|
* @param account azure account
|
||||||
|
|||||||
@@ -0,0 +1,168 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
|
import * as constants from '../../../common/constants';
|
||||||
|
import { IDataComponent } from '../../interfaces';
|
||||||
|
import { ColumnsTable } from './columnsTable';
|
||||||
|
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View to render filters to pick an azure resource
|
||||||
|
*/
|
||||||
|
export class ColumnsFilterComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
|
||||||
|
|
||||||
|
private _form: azdata.FormContainer | undefined;
|
||||||
|
private _databases: azdata.DropDownComponent | undefined;
|
||||||
|
private _tables: azdata.DropDownComponent | undefined;
|
||||||
|
private _columns: ColumnsTable | undefined;
|
||||||
|
private _dbNames: string[] = [];
|
||||||
|
private _tableNames: DatabaseTable[] = [];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new view
|
||||||
|
*/
|
||||||
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||||
|
super(apiWrapper, parent.root, parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register components
|
||||||
|
* @param modelBuilder model builder
|
||||||
|
*/
|
||||||
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
this._databases = modelBuilder.dropDown().withProperties({
|
||||||
|
width: this.componentMaxLength
|
||||||
|
}).component();
|
||||||
|
this._tables = modelBuilder.dropDown().withProperties({
|
||||||
|
width: this.componentMaxLength
|
||||||
|
}).component();
|
||||||
|
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
|
||||||
|
|
||||||
|
this._databases.onValueChanged(async () => {
|
||||||
|
await this.onDatabaseSelected();
|
||||||
|
});
|
||||||
|
|
||||||
|
this._tables.onValueChanged(async () => {
|
||||||
|
await this.onTableSelected();
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
this._form = modelBuilder.formContainer().withFormItems([{
|
||||||
|
title: constants.azureAccount,
|
||||||
|
component: this._databases
|
||||||
|
}, {
|
||||||
|
title: constants.azureSubscription,
|
||||||
|
component: this._tables
|
||||||
|
}, {
|
||||||
|
title: constants.azureGroup,
|
||||||
|
component: this._columns.component
|
||||||
|
}]).component();
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._databases && this._tables && this._columns) {
|
||||||
|
formBuilder.addFormItems([{
|
||||||
|
title: constants.columnDatabase,
|
||||||
|
component: this._databases
|
||||||
|
}, {
|
||||||
|
title: constants.columnTable,
|
||||||
|
component: this._tables
|
||||||
|
}, {
|
||||||
|
title: constants.inputColumns,
|
||||||
|
component: this._columns.component
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._databases && this._tables && this._columns) {
|
||||||
|
formBuilder.removeFormItem({
|
||||||
|
title: constants.azureAccount,
|
||||||
|
component: this._databases
|
||||||
|
});
|
||||||
|
formBuilder.removeFormItem({
|
||||||
|
title: constants.azureSubscription,
|
||||||
|
component: this._tables
|
||||||
|
});
|
||||||
|
formBuilder.removeFormItem({
|
||||||
|
title: constants.azureGroup,
|
||||||
|
component: this._columns.component
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the created component
|
||||||
|
*/
|
||||||
|
public get component(): azdata.Component | undefined {
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): PredictInputParameters | undefined {
|
||||||
|
return Object.assign({}, this.databaseTable, {
|
||||||
|
inputColumns: this.columnNames
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* loads data in the components
|
||||||
|
*/
|
||||||
|
public async loadData(): Promise<void> {
|
||||||
|
this._dbNames = await this.listDatabaseNames();
|
||||||
|
if (this._databases && this._dbNames && this._dbNames.length > 0) {
|
||||||
|
this._databases.values = this._dbNames;
|
||||||
|
this._databases.value = this._dbNames[0];
|
||||||
|
}
|
||||||
|
await this.onDatabaseSelected();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* refreshes the view
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
await this.loadData();
|
||||||
|
}
|
||||||
|
|
||||||
|
private async onDatabaseSelected(): Promise<void> {
|
||||||
|
this._tableNames = await this.listTableNames(this.databaseName || '');
|
||||||
|
if (this._tables && this._tableNames && this._tableNames.length > 0) {
|
||||||
|
this._tables.values = this._tableNames.map(t => this.getTableFullName(t));
|
||||||
|
this._tables.value = this.getTableFullName(this._tableNames[0]);
|
||||||
|
}
|
||||||
|
await this.onTableSelected();
|
||||||
|
}
|
||||||
|
|
||||||
|
private getTableFullName(table: DatabaseTable): string {
|
||||||
|
return `${table.schema}.${table.tableName}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async onTableSelected(): Promise<void> {
|
||||||
|
this._columns?.loadData(this.databaseTable);
|
||||||
|
}
|
||||||
|
|
||||||
|
private get databaseName(): string | undefined {
|
||||||
|
return <string>this._databases?.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
private get databaseTable(): DatabaseTable {
|
||||||
|
let selectedItem = this._tableNames.find(x => this.getTableFullName(x) === this._tables?.value);
|
||||||
|
return {
|
||||||
|
databaseName: this.databaseName,
|
||||||
|
tableName: selectedItem?.tableName,
|
||||||
|
schema: selectedItem?.schema
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private get columnNames(): PredictColumn[] | undefined {
|
||||||
|
return this._columns?.data;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
|
import * as constants from '../../../common/constants';
|
||||||
|
import { IPageView, IDataComponent } from '../../interfaces';
|
||||||
|
import { ColumnsFilterComponent } from './columnsFilterComponent';
|
||||||
|
import { OutputColumnsComponent } from './outputColumnsComponent';
|
||||||
|
import { PredictParameters } from '../../../prediction/interfaces';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View to pick model source
|
||||||
|
*/
|
||||||
|
export class ColumnsSelectionPage extends ModelViewBase implements IPageView, IDataComponent<PredictParameters> {
|
||||||
|
|
||||||
|
private _form: azdata.FormContainer | undefined;
|
||||||
|
private _formBuilder: azdata.FormBuilder | undefined;
|
||||||
|
public columnsFilterComponent: ColumnsFilterComponent | undefined;
|
||||||
|
public outputColumnsComponent: OutputColumnsComponent | undefined;
|
||||||
|
|
||||||
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||||
|
super(apiWrapper, parent.root, parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param modelBuilder Register components
|
||||||
|
*/
|
||||||
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
this._formBuilder = modelBuilder.formContainer();
|
||||||
|
this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this);
|
||||||
|
this.columnsFilterComponent.registerComponent(modelBuilder);
|
||||||
|
this.columnsFilterComponent.addComponents(this._formBuilder);
|
||||||
|
this.refresh();
|
||||||
|
|
||||||
|
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
|
||||||
|
this.outputColumnsComponent.registerComponent(modelBuilder);
|
||||||
|
this.outputColumnsComponent.addComponents(this._formBuilder);
|
||||||
|
this.refresh();
|
||||||
|
this._form = this._formBuilder.component();
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): PredictParameters | undefined {
|
||||||
|
return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ?
|
||||||
|
Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
|
||||||
|
undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the component
|
||||||
|
*/
|
||||||
|
public get component(): azdata.Component | undefined {
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refreshes the view
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
if (this._formBuilder) {
|
||||||
|
if (this.columnsFilterComponent) {
|
||||||
|
await this.columnsFilterComponent.refresh();
|
||||||
|
}
|
||||||
|
if (this.outputColumnsComponent) {
|
||||||
|
await this.outputColumnsComponent.refresh();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns page title
|
||||||
|
*/
|
||||||
|
public get title(): string {
|
||||||
|
return constants.columnSelectionPageTitle;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import * as constants from '../../../common/constants';
|
||||||
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
|
import { IDataComponent } from '../../interfaces';
|
||||||
|
import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View to render azure models in a table
|
||||||
|
*/
|
||||||
|
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
|
||||||
|
|
||||||
|
private _table: azdata.DeclarativeTableComponent;
|
||||||
|
private _selectedColumns: PredictColumn[] = [];
|
||||||
|
private _columns: string[] | undefined;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a view to render azure models in a table
|
||||||
|
*/
|
||||||
|
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||||
|
super(apiWrapper, parent.root, parent);
|
||||||
|
this._table = this.registerComponent(this._modelBuilder);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register components
|
||||||
|
* @param modelBuilder model builder
|
||||||
|
*/
|
||||||
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
||||||
|
this._table = modelBuilder.declarativeTable()
|
||||||
|
.withProperties<azdata.DeclarativeTableProperties>(
|
||||||
|
{
|
||||||
|
columns: [
|
||||||
|
{ // Name
|
||||||
|
displayName: constants.columnDatabase,
|
||||||
|
ariaLabel: constants.columnName,
|
||||||
|
valueType: azdata.DeclarativeDataType.string,
|
||||||
|
isReadOnly: true,
|
||||||
|
width: 120,
|
||||||
|
headerCssStyles: {
|
||||||
|
...constants.cssStyles.tableHeader
|
||||||
|
},
|
||||||
|
rowCssStyles: {
|
||||||
|
...constants.cssStyles.tableRow
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ // Action
|
||||||
|
displayName: constants.inputName,
|
||||||
|
ariaLabel: constants.inputName,
|
||||||
|
valueType: azdata.DeclarativeDataType.component,
|
||||||
|
isReadOnly: true,
|
||||||
|
width: 50,
|
||||||
|
headerCssStyles: {
|
||||||
|
...constants.cssStyles.tableHeader
|
||||||
|
},
|
||||||
|
rowCssStyles: {
|
||||||
|
...constants.cssStyles.tableRow
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ // Action
|
||||||
|
displayName: '',
|
||||||
|
valueType: azdata.DeclarativeDataType.component,
|
||||||
|
isReadOnly: true,
|
||||||
|
width: 50,
|
||||||
|
headerCssStyles: {
|
||||||
|
...constants.cssStyles.tableHeader
|
||||||
|
},
|
||||||
|
rowCssStyles: {
|
||||||
|
...constants.cssStyles.tableRow
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
data: [],
|
||||||
|
ariaLabel: constants.mlsConfigTitle
|
||||||
|
})
|
||||||
|
.component();
|
||||||
|
return this._table;
|
||||||
|
}
|
||||||
|
|
||||||
|
public get component(): azdata.DeclarativeTableComponent {
|
||||||
|
return this._table;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load data in the component
|
||||||
|
* @param workspaceResource Azure workspace
|
||||||
|
*/
|
||||||
|
public async loadData(table: DatabaseTable): Promise<void> {
|
||||||
|
this._selectedColumns = [];
|
||||||
|
if (this._table) {
|
||||||
|
this._columns = await this.listColumnNames(table);
|
||||||
|
let tableData: any[][] = [];
|
||||||
|
|
||||||
|
if (this._columns) {
|
||||||
|
tableData = tableData.concat(this._columns.map(model => this.createTableRow(model)));
|
||||||
|
}
|
||||||
|
|
||||||
|
this._table.data = tableData;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private createTableRow(column: string): any[] {
|
||||||
|
if (this._modelBuilder) {
|
||||||
|
let selectRowButton = this._modelBuilder.checkBox().withProperties({
|
||||||
|
|
||||||
|
width: 15,
|
||||||
|
height: 15,
|
||||||
|
checked: true
|
||||||
|
}).component();
|
||||||
|
let nameInputBox = this._modelBuilder.inputBox().withProperties({
|
||||||
|
value: '',
|
||||||
|
width: 150
|
||||||
|
}).component();
|
||||||
|
this._selectedColumns.push({ name: column });
|
||||||
|
selectRowButton.onChanged(() => {
|
||||||
|
if (selectRowButton.checked) {
|
||||||
|
if (!this._selectedColumns.find(x => x.name === column)) {
|
||||||
|
this._selectedColumns.push({ name: column });
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (this._selectedColumns.find(x => x.name === column)) {
|
||||||
|
this._selectedColumns = this._selectedColumns.filter(x => x.name !== column);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
nameInputBox.onTextChanged(() => {
|
||||||
|
let selectedRow = this._selectedColumns.find(x => x.name === column);
|
||||||
|
if (selectedRow) {
|
||||||
|
selectedRow.displayName = nameInputBox.value;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return [column, nameInputBox, selectRowButton];
|
||||||
|
}
|
||||||
|
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): PredictColumn[] | undefined {
|
||||||
|
return this._selectedColumns;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refreshes the view
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
|
import * as constants from '../../../common/constants';
|
||||||
|
import { IDataComponent } from '../../interfaces';
|
||||||
|
import { PredictColumn } from '../../../prediction/interfaces';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View to render filters to pick an azure resource
|
||||||
|
*/
|
||||||
|
const componentWidth = 60;
|
||||||
|
export class OutputColumnsComponent extends ModelViewBase implements IDataComponent<PredictColumn[]> {
|
||||||
|
|
||||||
|
private _form: azdata.FormContainer | undefined;
|
||||||
|
private _flex: azdata.FlexContainer | undefined;
|
||||||
|
private _columnName: azdata.InputBoxComponent | undefined;
|
||||||
|
private _columnTypes: azdata.DropDownComponent | undefined;
|
||||||
|
private _dataTypes: string[] = [
|
||||||
|
'int',
|
||||||
|
'nvarchar(MAX)',
|
||||||
|
'varchar(MAX)',
|
||||||
|
'float',
|
||||||
|
'double',
|
||||||
|
'bit'
|
||||||
|
];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new view
|
||||||
|
*/
|
||||||
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||||
|
super(apiWrapper, parent.root, parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register components
|
||||||
|
* @param modelBuilder model builder
|
||||||
|
*/
|
||||||
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
|
this._columnName = modelBuilder.inputBox().withProperties({
|
||||||
|
width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength
|
||||||
|
}).component();
|
||||||
|
this._columnTypes = modelBuilder.dropDown().withProperties({
|
||||||
|
width: componentWidth
|
||||||
|
}).component();
|
||||||
|
|
||||||
|
let flex = modelBuilder.flexContainer()
|
||||||
|
.withLayout({
|
||||||
|
width: this._columnName.width
|
||||||
|
}).withItems([
|
||||||
|
this._columnName]
|
||||||
|
).component();
|
||||||
|
this._flex = modelBuilder.flexContainer()
|
||||||
|
.withLayout({
|
||||||
|
flexFlow: 'row',
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
width: this.componentMaxLength
|
||||||
|
}).withItems([
|
||||||
|
flex, this._columnTypes]
|
||||||
|
).component();
|
||||||
|
|
||||||
|
this._form = modelBuilder.formContainer().withFormItems([{
|
||||||
|
title: constants.azureAccount,
|
||||||
|
component: this._flex
|
||||||
|
}]).component();
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._flex) {
|
||||||
|
formBuilder.addFormItems([{
|
||||||
|
title: constants.outputColumns,
|
||||||
|
component: this._flex
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this._flex) {
|
||||||
|
formBuilder.removeFormItem({
|
||||||
|
title: constants.outputColumns,
|
||||||
|
component: this._flex
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the created component
|
||||||
|
*/
|
||||||
|
public get component(): azdata.Component | undefined {
|
||||||
|
return this._form;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* loads data in the components
|
||||||
|
*/
|
||||||
|
public async loadData(): Promise<void> {
|
||||||
|
if (this._columnTypes) {
|
||||||
|
this._columnTypes.values = this._dataTypes;
|
||||||
|
this._columnTypes.value = this._dataTypes[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* refreshes the view
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
await this.loadData();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): PredictColumn[] | undefined {
|
||||||
|
return this._columnName && this._columnTypes ? [{
|
||||||
|
name: this._columnName.value || '',
|
||||||
|
dataType: <string>this._columnTypes.value || ''
|
||||||
|
}] : undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
/*---------------------------------------------------------------------------------------------
|
||||||
|
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||||
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
import * as azdata from 'azdata';
|
||||||
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
|
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
|
||||||
|
import { LocalModelsComponent } from '../localModelsComponent';
|
||||||
|
import { AzureModelsComponent } from '../azureModelsComponent';
|
||||||
|
import * as constants from '../../../common/constants';
|
||||||
|
import { WizardView } from '../../wizardView';
|
||||||
|
import { ModelSourcePage } from '../modelSourcePage';
|
||||||
|
import { ColumnsSelectionPage } from './columnsSelectionPage';
|
||||||
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wizard to register a model
|
||||||
|
*/
|
||||||
|
export class PredictWizard extends ModelViewBase {
|
||||||
|
|
||||||
|
public modelSourcePage: ModelSourcePage | undefined;
|
||||||
|
//public modelDetailsPage: ModelDetailsPage | undefined;
|
||||||
|
public columnsSelectionPage: ColumnsSelectionPage | undefined;
|
||||||
|
public wizardView: WizardView | undefined;
|
||||||
|
private _parentView: ModelViewBase | undefined;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
apiWrapper: ApiWrapper,
|
||||||
|
root: string,
|
||||||
|
parent?: ModelViewBase) {
|
||||||
|
super(apiWrapper, root);
|
||||||
|
this._parentView = parent;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Opens a dialog to manage packages used by notebooks.
|
||||||
|
*/
|
||||||
|
public open(): void {
|
||||||
|
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
|
||||||
|
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
|
||||||
|
this.wizardView = new WizardView(this._apiWrapper);
|
||||||
|
|
||||||
|
let wizard = this.wizardView.createWizard(constants.makePredictionTitle,
|
||||||
|
[this.modelSourcePage,
|
||||||
|
this.columnsSelectionPage]);
|
||||||
|
|
||||||
|
this.mainViewPanel = wizard;
|
||||||
|
wizard.doneButton.label = constants.predictModel;
|
||||||
|
wizard.generateScriptButton.hidden = true;
|
||||||
|
wizard.displayPageTitles = true;
|
||||||
|
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||||
|
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||||
|
if (validated && pageInfo.newPage === undefined) {
|
||||||
|
wizard.cancelButton.enabled = false;
|
||||||
|
wizard.backButton.enabled = false;
|
||||||
|
await this.predict();
|
||||||
|
wizard.cancelButton.enabled = true;
|
||||||
|
wizard.backButton.enabled = true;
|
||||||
|
if (this._parentView) {
|
||||||
|
this._parentView?.refresh();
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
|
||||||
|
}
|
||||||
|
return validated;
|
||||||
|
});
|
||||||
|
|
||||||
|
wizard.open();
|
||||||
|
}
|
||||||
|
|
||||||
|
public get modelResources(): ModelSourcesComponent | undefined {
|
||||||
|
return this.modelSourcePage?.modelResources;
|
||||||
|
}
|
||||||
|
|
||||||
|
public get localModelsComponent(): LocalModelsComponent | undefined {
|
||||||
|
return this.modelSourcePage?.localModelsComponent;
|
||||||
|
}
|
||||||
|
|
||||||
|
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
||||||
|
return this.modelSourcePage?.azureModelsComponent;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async predict(): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
let modelFilePath: string = '';
|
||||||
|
let registeredModel: RegisteredModel | undefined = undefined;
|
||||||
|
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||||
|
modelFilePath = this.localModelsComponent.data;
|
||||||
|
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
|
||||||
|
modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data);
|
||||||
|
} else {
|
||||||
|
registeredModel = this.modelSourcePage?.registeredModelsComponent?.data;
|
||||||
|
}
|
||||||
|
|
||||||
|
await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data);
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refresh the pages
|
||||||
|
*/
|
||||||
|
public async refresh(): Promise<void> {
|
||||||
|
await this.wizardView?.refresh();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,11 +5,11 @@
|
|||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
|
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../../common/constants';
|
||||||
import { ModelViewBase, RegisterModelEventName } from './modelViewBase';
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
import { CurrentModelsTable } from './currentModelsTable';
|
import { CurrentModelsTable } from './currentModelsTable';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { IPageView } from '../interfaces';
|
import { IPageView } from '../../interfaces';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to render current registered models
|
* View to render current registered models
|
||||||
@@ -33,28 +33,21 @@ export class CurrentModelsPage extends ModelViewBase implements IPageView {
|
|||||||
* @param modelBuilder register the components
|
* @param modelBuilder register the components
|
||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||||
this._dataTable = new CurrentModelsTable(this._apiWrapper, modelBuilder, this);
|
this._dataTable = new CurrentModelsTable(this._apiWrapper, this);
|
||||||
|
this._dataTable.registerComponent(modelBuilder);
|
||||||
this._tableComponent = this._dataTable.component;
|
this._tableComponent = this._dataTable.component;
|
||||||
|
|
||||||
let registerButton = modelBuilder.button().withProperties({
|
let formModelBuilder = modelBuilder.formContainer();
|
||||||
label: constants.registerModelTitle,
|
|
||||||
width: this.buttonMaxLength
|
|
||||||
}).component();
|
|
||||||
registerButton.onDidClick(async () => {
|
|
||||||
await this.sendDataRequest(RegisterModelEventName);
|
|
||||||
});
|
|
||||||
|
|
||||||
let formModel = modelBuilder.formContainer()
|
if (this._tableComponent) {
|
||||||
.withFormItems([{
|
formModelBuilder.addFormItem({
|
||||||
title: '',
|
|
||||||
component: registerButton
|
|
||||||
}, {
|
|
||||||
component: this._tableComponent,
|
component: this._tableComponent,
|
||||||
title: ''
|
title: ''
|
||||||
}]).component();
|
});
|
||||||
|
}
|
||||||
|
|
||||||
this._loader = modelBuilder.loadingComponent()
|
this._loader = modelBuilder.loadingComponent()
|
||||||
.withItem(formModel)
|
.withItem(formModelBuilder.component())
|
||||||
.withProperties({
|
.withProperties({
|
||||||
loading: true
|
loading: true
|
||||||
}).component();
|
}).component();
|
||||||
@@ -4,24 +4,26 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../../common/constants';
|
||||||
import { ModelViewBase } from './modelViewBase';
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { RegisteredModel } from '../../modelManagement/interfaces';
|
import { RegisteredModel } from '../../../modelManagement/interfaces';
|
||||||
|
import { IDataComponent } from '../../interfaces';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* View to render registered models table
|
* View to render registered models table
|
||||||
*/
|
*/
|
||||||
export class CurrentModelsTable extends ModelViewBase {
|
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<RegisteredModel> {
|
||||||
|
|
||||||
private _table: azdata.DeclarativeTableComponent;
|
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||||
|
private _modelBuilder: azdata.ModelBuilder | undefined;
|
||||||
|
private _selectedModel: any;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates new view
|
* Creates new view
|
||||||
*/
|
*/
|
||||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||||
super(apiWrapper, parent.root, parent);
|
super(apiWrapper, parent.root, parent);
|
||||||
this._table = this.registerComponent(this._modelBuilder);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -29,6 +31,7 @@ export class CurrentModelsTable extends ModelViewBase {
|
|||||||
* @param modelBuilder register the components
|
* @param modelBuilder register the components
|
||||||
*/
|
*/
|
||||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
||||||
|
this._modelBuilder = modelBuilder;
|
||||||
this._table = modelBuilder.declarativeTable()
|
this._table = modelBuilder.declarativeTable()
|
||||||
.withProperties<azdata.DeclarativeTableProperties>(
|
.withProperties<azdata.DeclarativeTableProperties>(
|
||||||
{
|
{
|
||||||
@@ -92,10 +95,23 @@ export class CurrentModelsTable extends ModelViewBase {
|
|||||||
return this._table;
|
return this._table;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this.component) {
|
||||||
|
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||||
|
if (this.component) {
|
||||||
|
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the component
|
* Returns the component
|
||||||
*/
|
*/
|
||||||
public get component(): azdata.DeclarativeTableComponent {
|
public get component(): azdata.DeclarativeTableComponent | undefined {
|
||||||
return this._table;
|
return this._table;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +119,7 @@ export class CurrentModelsTable extends ModelViewBase {
|
|||||||
* Loads the data in the component
|
* Loads the data in the component
|
||||||
*/
|
*/
|
||||||
public async loadData(): Promise<void> {
|
public async loadData(): Promise<void> {
|
||||||
|
if (this._table) {
|
||||||
let models: RegisteredModel[] | undefined;
|
let models: RegisteredModel[] | undefined;
|
||||||
|
|
||||||
models = await this.listModels();
|
models = await this.listModels();
|
||||||
@@ -114,27 +131,33 @@ export class CurrentModelsTable extends ModelViewBase {
|
|||||||
|
|
||||||
this._table.data = tableData;
|
this._table.data = tableData;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private createTableRow(model: RegisteredModel): any[] {
|
private createTableRow(model: RegisteredModel): any[] {
|
||||||
if (this._modelBuilder) {
|
if (this._modelBuilder) {
|
||||||
let editLanguageButton = this._modelBuilder.button().withProperties({
|
let selectModelButton = this._modelBuilder.radioButton().withProperties({
|
||||||
label: '',
|
name: 'amlModel',
|
||||||
title: constants.deleteTitle,
|
value: model.id,
|
||||||
iconPath: {
|
|
||||||
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
|
|
||||||
light: this.asAbsolutePath('images/light/edit.svg')
|
|
||||||
},
|
|
||||||
width: 15,
|
width: 15,
|
||||||
height: 15
|
height: 15,
|
||||||
|
checked: false
|
||||||
}).component();
|
}).component();
|
||||||
editLanguageButton.onDidClick(() => {
|
selectModelButton.onDidClick(() => {
|
||||||
|
this._selectedModel = model;
|
||||||
});
|
});
|
||||||
return [model.artifactName, model.title, model.created, editLanguageButton];
|
return [model.artifactName, model.title, model.created, selectModelButton];
|
||||||
}
|
}
|
||||||
|
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns selected data
|
||||||
|
*/
|
||||||
|
public get data(): RegisteredModel | undefined {
|
||||||
|
return this._selectedModel;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Refreshes the view
|
* Refreshes the view
|
||||||
*/
|
*/
|
||||||
@@ -4,15 +4,15 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
import { ModelViewBase } from './modelViewBase';
|
import { ModelViewBase } from '../modelViewBase';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { ModelSourcesComponent, ModelSourceType } from './modelSourcesComponent';
|
import { ModelSourcesComponent, ModelSourceType } from '../modelSourcesComponent';
|
||||||
import { LocalModelsComponent } from './localModelsComponent';
|
import { LocalModelsComponent } from '../localModelsComponent';
|
||||||
import { AzureModelsComponent } from './azureModelsComponent';
|
import { AzureModelsComponent } from '../azureModelsComponent';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../../common/constants';
|
||||||
import { WizardView } from '../wizardView';
|
import { WizardView } from '../../wizardView';
|
||||||
import { ModelSourcePage } from './modelSourcePage';
|
import { ModelSourcePage } from '../modelSourcePage';
|
||||||
import { ModelDetailsPage } from './modelDetailsPage';
|
import { ModelDetailsPage } from '../modelDetailsPage';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wizard to register a model
|
* Wizard to register a model
|
||||||
@@ -47,19 +47,20 @@ export class RegisterModelWizard extends ModelViewBase {
|
|||||||
wizard.generateScriptButton.hidden = true;
|
wizard.generateScriptButton.hidden = true;
|
||||||
wizard.displayPageTitles = true;
|
wizard.displayPageTitles = true;
|
||||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||||
if (pageInfo.newPage === undefined) {
|
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||||
|
if (validated && pageInfo.newPage === undefined) {
|
||||||
wizard.cancelButton.enabled = false;
|
wizard.cancelButton.enabled = false;
|
||||||
wizard.backButton.enabled = false;
|
wizard.backButton.enabled = false;
|
||||||
await this.registerModel();
|
let result = await this.registerModel();
|
||||||
wizard.cancelButton.enabled = true;
|
wizard.cancelButton.enabled = true;
|
||||||
wizard.backButton.enabled = true;
|
wizard.backButton.enabled = true;
|
||||||
if (this._parentView) {
|
if (this._parentView) {
|
||||||
this._parentView?.refresh();
|
await this._parentView?.refresh();
|
||||||
}
|
}
|
||||||
return true;
|
return result;
|
||||||
|
|
||||||
}
|
}
|
||||||
return true;
|
return validated;
|
||||||
});
|
});
|
||||||
|
|
||||||
wizard.open();
|
wizard.open();
|
||||||
@@ -5,10 +5,10 @@
|
|||||||
|
|
||||||
import { CurrentModelsPage } from './currentModelsPage';
|
import { CurrentModelsPage } from './currentModelsPage';
|
||||||
|
|
||||||
import { ModelViewBase } from './modelViewBase';
|
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
|
||||||
import * as constants from '../../common/constants';
|
import * as constants from '../../../common/constants';
|
||||||
import { ApiWrapper } from '../../common/apiWrapper';
|
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||||
import { DialogView } from '../dialogView';
|
import { DialogView } from '../../dialogView';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dialog to render registered model views
|
* Dialog to render registered model views
|
||||||
@@ -31,7 +31,13 @@ export class RegisteredModelsDialog extends ModelViewBase {
|
|||||||
|
|
||||||
this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this);
|
this.currentLanguagesTab = new CurrentModelsPage(this._apiWrapper, this);
|
||||||
|
|
||||||
|
let registerModelButton = this._apiWrapper.createButton(constants.registerModelTitle);
|
||||||
|
registerModelButton.onClick(async () => {
|
||||||
|
await this.sendDataRequest(RegisterModelEventName);
|
||||||
|
});
|
||||||
|
|
||||||
let dialog = this.dialogView.createDialog('', [this.currentLanguagesTab]);
|
let dialog = this.dialogView.createDialog('', [this.currentLanguagesTab]);
|
||||||
|
dialog.customButtons = [registerModelButton];
|
||||||
this.mainViewPanel = dialog;
|
this.mainViewPanel = dialog;
|
||||||
dialog.okButton.hidden = true;
|
dialog.okButton.hidden = true;
|
||||||
dialog.cancelButton.label = constants.extLangDoneButtonText;
|
dialog.cancelButton.label = constants.extLangDoneButtonText;
|
||||||
@@ -4,6 +4,8 @@
|
|||||||
*--------------------------------------------------------------------------------------------*/
|
*--------------------------------------------------------------------------------------------*/
|
||||||
|
|
||||||
import * as azdata from 'azdata';
|
import * as azdata from 'azdata';
|
||||||
|
import * as vscode from 'vscode';
|
||||||
|
|
||||||
import * as constants from '../common/constants';
|
import * as constants from '../common/constants';
|
||||||
import { ApiWrapper } from '../common/apiWrapper';
|
import { ApiWrapper } from '../common/apiWrapper';
|
||||||
import * as path from 'path';
|
import * as path from 'path';
|
||||||
@@ -21,8 +23,7 @@ export interface CallbackEventArgs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const CallEventNamePostfix = 'Callback';
|
export const CallEventNamePostfix = 'Callback';
|
||||||
export const LocalFileEventName = 'localFile';
|
export const LocalPathsEventName = 'localPaths';
|
||||||
export const LocalFolderEventName = 'localFolder';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class for views
|
* Base class for views
|
||||||
@@ -51,7 +52,7 @@ export abstract class ViewBase extends EventEmitterCollection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected getEventNames(): string[] {
|
protected getEventNames(): string[] {
|
||||||
return [LocalFolderEventName, LocalFileEventName];
|
return [LocalPathsEventName];
|
||||||
}
|
}
|
||||||
|
|
||||||
protected getCallbackEventNames(): string[] {
|
protected getCallbackEventNames(): string[] {
|
||||||
@@ -118,12 +119,8 @@ export abstract class ViewBase extends EventEmitterCollection {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getLocalFilePath(): Promise<string> {
|
public async getLocalPaths(options: vscode.OpenDialogOptions): Promise<string[]> {
|
||||||
return await this.sendDataRequest(LocalFileEventName);
|
return await this.sendDataRequest(LocalPathsEventName, options);
|
||||||
}
|
|
||||||
|
|
||||||
public async getLocalFolderPath(): Promise<string> {
|
|
||||||
return await this.sendDataRequest(LocalFolderEventName);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getLocationTitle(): Promise<string> {
|
public async getLocationTitle(): Promise<string> {
|
||||||
@@ -174,12 +171,12 @@ export abstract class ViewBase extends EventEmitterCollection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public showErrorMessage(message: string, error?: any): void {
|
public showErrorMessage(message: string, error?: any): void {
|
||||||
this.showMessage(`${message} ${constants.getErrorMessage(error)}`, azdata.window.MessageLevel.Error);
|
this.showMessage(`${message} ${error ? constants.getErrorMessage(error) : ''}`, azdata.window.MessageLevel.Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
private showMessage(message: string, level: azdata.window.MessageLevel): void {
|
private showMessage(message: string, level: azdata.window.MessageLevel): void {
|
||||||
if (this._mainViewPanel) {
|
if (this.mainViewPanel) {
|
||||||
this._mainViewPanel.message = {
|
this.mainViewPanel.message = {
|
||||||
text: message,
|
text: message,
|
||||||
level: level
|
level: level
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -45,6 +45,19 @@ export class WizardView extends MainViewBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds wizard page
|
||||||
|
* @param page page
|
||||||
|
* @param index page index
|
||||||
|
*/
|
||||||
|
public removeWizardPage(page: IPageView, index: number): void {
|
||||||
|
if (this._wizard && this._pages[index] === page) {
|
||||||
|
this._pages = this._pages.splice(index);
|
||||||
|
this._wizard.removePage(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param title Creates anew wizard
|
* @param title Creates anew wizard
|
||||||
@@ -57,9 +70,21 @@ export class WizardView extends MainViewBase {
|
|||||||
this._wizard.onPageChanged(async (info) => {
|
this._wizard.onPageChanged(async (info) => {
|
||||||
this.onWizardPageChanged(info);
|
this.onWizardPageChanged(info);
|
||||||
});
|
});
|
||||||
|
|
||||||
return this._wizard;
|
return this._wizard;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async validate(pageInfo: azdata.window.WizardPageChangeInfo): Promise<boolean> {
|
||||||
|
if (pageInfo.lastPage !== undefined) {
|
||||||
|
let idxLast = pageInfo.lastPage;
|
||||||
|
let lastPage = this._pages[idxLast];
|
||||||
|
if (lastPage && lastPage.validate) {
|
||||||
|
return await lastPage.validate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
|
private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
|
||||||
let idxLast = pageInfo.lastPage;
|
let idxLast = pageInfo.lastPage;
|
||||||
let lastPage = this._pages[idxLast];
|
let lastPage = this._pages[idxLast];
|
||||||
@@ -73,4 +98,8 @@ export class WizardView extends MainViewBase {
|
|||||||
page.onEnter();
|
page.onEnter();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public get wizard(): azdata.window.Wizard | undefined {
|
||||||
|
return this._wizard;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user