mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-08 17:24:01 -05:00
ML - dashboard icons and links (#10153)
* ML - dashboard icons and links
This commit is contained in:
136
extensions/machine-learning/src/common/apiWrapper.ts
Normal file
136
extensions/machine-learning/src/common/apiWrapper.ts
Normal file
@@ -0,0 +1,136 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
/**
|
||||
* Wrapper class to act as a facade over VSCode and Data APIs and allow us to test / mock callbacks into
|
||||
* this API from our code
|
||||
*/
|
||||
export class ApiWrapper {
|
||||
public createOutputChannel(name: string): vscode.OutputChannel {
|
||||
return vscode.window.createOutputChannel(name);
|
||||
}
|
||||
|
||||
public createTerminalWithOptions(options: vscode.TerminalOptions): vscode.Terminal {
|
||||
return vscode.window.createTerminal(options);
|
||||
}
|
||||
|
||||
public getCurrentConnection(): Thenable<azdata.connection.ConnectionProfile> {
|
||||
return azdata.connection.getCurrentConnection();
|
||||
}
|
||||
|
||||
public getCredentials(connectionId: string): Thenable<{ [name: string]: string }> {
|
||||
return azdata.connection.getCredentials(connectionId);
|
||||
}
|
||||
|
||||
public registerCommand(command: string, callback: (...args: any[]) => any, thisArg?: any): vscode.Disposable {
|
||||
return vscode.commands.registerCommand(command, callback, thisArg);
|
||||
}
|
||||
|
||||
public executeCommand<T>(command: string, ...rest: any[]): Thenable<T | undefined> {
|
||||
return vscode.commands.executeCommand(command, ...rest);
|
||||
}
|
||||
public registerTaskHandler(taskId: string, handler: (profile: azdata.IConnectionProfile) => void): void {
|
||||
azdata.tasks.registerTask(taskId, handler);
|
||||
}
|
||||
|
||||
public getUriForConnection(connectionId: string): Thenable<string> {
|
||||
return azdata.connection.getUriForConnection(connectionId);
|
||||
}
|
||||
|
||||
public getProvider<T extends azdata.DataProvider>(providerId: string, providerType: azdata.DataProviderType): T {
|
||||
return azdata.dataprotocol.getProvider<T>(providerId, providerType);
|
||||
}
|
||||
|
||||
public showErrorMessage(message: string, ...items: string[]): Thenable<string | undefined> {
|
||||
return vscode.window.showErrorMessage(message, ...items);
|
||||
}
|
||||
|
||||
public showInfoMessage(message: string, ...items: string[]): Thenable<string | undefined> {
|
||||
return vscode.window.showInformationMessage(message, ...items);
|
||||
}
|
||||
|
||||
public showOpenDialog(options: vscode.OpenDialogOptions): Thenable<vscode.Uri[] | undefined> {
|
||||
return vscode.window.showOpenDialog(options);
|
||||
}
|
||||
|
||||
public startBackgroundOperation(operationInfo: azdata.BackgroundOperationInfo): void {
|
||||
azdata.tasks.startBackgroundOperation(operationInfo);
|
||||
}
|
||||
|
||||
public openExternal(target: vscode.Uri): Thenable<boolean> {
|
||||
return vscode.env.openExternal(target);
|
||||
}
|
||||
|
||||
public getExtension(extensionId: string): vscode.Extension<any> | undefined {
|
||||
return vscode.extensions.getExtension(extensionId);
|
||||
}
|
||||
|
||||
public getConfiguration(section?: string, resource?: vscode.Uri | null): vscode.WorkspaceConfiguration {
|
||||
return vscode.workspace.getConfiguration(section, resource);
|
||||
}
|
||||
|
||||
public createTab(title: string): azdata.window.DialogTab {
|
||||
return azdata.window.createTab(title);
|
||||
}
|
||||
|
||||
public createModelViewDialog(title: string, dialogName?: string, isWide?: boolean): azdata.window.Dialog {
|
||||
return azdata.window.createModelViewDialog(title, dialogName, isWide);
|
||||
}
|
||||
|
||||
public createWizard(title: string): azdata.window.Wizard {
|
||||
return azdata.window.createWizard(title);
|
||||
}
|
||||
|
||||
public createWizardPage(title: string): azdata.window.WizardPage {
|
||||
return azdata.window.createWizardPage(title);
|
||||
}
|
||||
|
||||
public openDialog(dialog: azdata.window.Dialog): void {
|
||||
return azdata.window.openDialog(dialog);
|
||||
}
|
||||
|
||||
public getAllAccounts(): Thenable<azdata.Account[]> {
|
||||
return azdata.accounts.getAllAccounts();
|
||||
}
|
||||
|
||||
public getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{ [key: string]: any }> {
|
||||
return azdata.accounts.getSecurityToken(account, resource);
|
||||
}
|
||||
|
||||
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
|
||||
return vscode.window.showQuickPick(items, options, token);
|
||||
}
|
||||
|
||||
public listDatabases(connectionId: string): Thenable<string[]> {
|
||||
return azdata.connection.listDatabases(connectionId);
|
||||
}
|
||||
|
||||
public openTextDocument(options?: { language?: string; content?: string; }): Thenable<vscode.TextDocument> {
|
||||
return vscode.workspace.openTextDocument(options);
|
||||
}
|
||||
|
||||
public connect(fileUri: string, connectionId: string): Thenable<void> {
|
||||
return azdata.queryeditor.connect(fileUri, connectionId);
|
||||
}
|
||||
|
||||
public runQuery(fileUri: string, options?: Map<string, string>, runCurrentQuery?: boolean): void {
|
||||
azdata.queryeditor.runQuery(fileUri, options, runCurrentQuery);
|
||||
}
|
||||
|
||||
public showTextDocument(uri: vscode.Uri, options?: vscode.TextDocumentShowOptions): Thenable<vscode.TextEditor> {
|
||||
return vscode.window.showTextDocument(uri, options);
|
||||
}
|
||||
|
||||
public createButton(label: string, position?: azdata.window.DialogButtonPosition): azdata.window.Button {
|
||||
return azdata.window.createButton(label, position);
|
||||
}
|
||||
|
||||
public registerWidget(widgetId: string, handler: (view: azdata.ModelView) => void): void {
|
||||
azdata.ui.registerModelViewProvider(widgetId, handler);
|
||||
}
|
||||
}
|
||||
235
extensions/machine-learning/src/common/constants.ts
Normal file
235
extensions/machine-learning/src/common/constants.ts
Normal file
@@ -0,0 +1,235 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as nls from 'vscode-nls';
|
||||
|
||||
const localize = nls.loadMessageBundle();
|
||||
|
||||
export const winPlatform = 'win32';
|
||||
export const pythonBundleVersion = '0.0.1';
|
||||
export const managePackagesCommand = 'jupyter.cmd.managePackages';
|
||||
export const pythonLanguageName = 'Python';
|
||||
export const rLanguageName = 'R';
|
||||
export const rLPackagedFolderName = 'r_packages';
|
||||
|
||||
export const mlEnableMlsCommand = 'mls.command.enableMls';
|
||||
export const mlDisableMlsCommand = 'mls.command.disableMls';
|
||||
export const extensionOutputChannel = 'Machine Learning';
|
||||
export const notebookExtensionName = 'Microsoft.notebook';
|
||||
export const azureSubscriptionsCommand = 'azure.accounts.getSubscriptions';
|
||||
export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups';
|
||||
export const signInToAzureCommand = 'azure.resource.signin';
|
||||
|
||||
// Tasks, commands
|
||||
//
|
||||
export const mlManageLanguagesCommand = 'mls.command.manageLanguages';
|
||||
export const mlsPredictModelCommand = 'mls.command.predictModel';
|
||||
export const mlManageModelsCommand = 'mls.command.manageModels';
|
||||
export const mlImportModelCommand = 'mls.command.importModel';
|
||||
export const mlManagePackagesCommand = 'mls.command.managePackages';
|
||||
export const mlsDependenciesCommand = 'mls.command.dependencies';
|
||||
export const notebookCommandNew = 'notebook.command.new';
|
||||
|
||||
// Configurations
|
||||
//
|
||||
export const mlsConfigKey = 'machineLearningServices';
|
||||
export const pythonPathConfigKey = 'pythonPath';
|
||||
export const pythonEnabledConfigKey = 'enablePython';
|
||||
export const rEnabledConfigKey = 'enableR';
|
||||
export const registeredModelsTableName = 'registeredModelsTableName';
|
||||
export const rPathConfigKey = 'rPath';
|
||||
|
||||
// Localized texts
|
||||
//
|
||||
export const msgYes = localize('msgYes', "Yes");
|
||||
export const msgNo = localize('msgNo', "No");
|
||||
export const managePackageCommandError = localize('mls.managePackages.error', "Either no connection is available or the server does not have external script enabled.");
|
||||
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
|
||||
export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Installing package management dependencies");
|
||||
export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Installing model management dependencies");
|
||||
export const noResultError = localize('mls.noResultError', "No Result returned");
|
||||
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
|
||||
export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that.");
|
||||
export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script.");
|
||||
export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action.");
|
||||
export function confirmInstallPythonPackages(packages: string): string {
|
||||
return localize('mls.installDependencies.confirmInstallPythonPackages'
|
||||
, "The following Python packages are required to install: {0}. Are you sure you want to install?", packages);
|
||||
}
|
||||
export function confirmDeleteModel(modelName: string): string {
|
||||
return localize('models.confirmDeleteModel'
|
||||
, "Are you sure you want to delete model '{0}?", modelName);
|
||||
}
|
||||
export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ...");
|
||||
export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
|
||||
export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
|
||||
export const noConnectionError = localize('mls.packageManager.NoConnection', "No connection selected");
|
||||
export const notebookExtensionNotLoaded = localize('mls.notebookExtensionNotLoaded', "Notebook extension is not loaded");
|
||||
export const mssqlExtensionNotLoaded = localize('mls.mssqlExtensionNotLoaded', "MSSQL extension is not loaded");
|
||||
export const mlsEnabledMessage = localize('mls.enabledMessage', "Machine Learning Services Enabled");
|
||||
export const mlsConfigUpdateFailed = localize('mls.configUpdateFailed', "Failed to modify Machine Learning Services configurations");
|
||||
export const mlsEnableButtonTitle = localize('mls.enableButtonTitle', "Enable");
|
||||
export const mlsDisableButtonTitle = localize('mls.disableButtonTitle', "Disable");
|
||||
export const mlsConfigTitle = localize('mls.configTitle', "Config");
|
||||
export const mlsConfigStatus = localize('mls.configStatus', "Enabled");
|
||||
export const mlsConfigAction = localize('mls.configAction', "Action");
|
||||
export const mlsExternalExecuteScriptTitle = localize('mls.externalExecuteScriptTitle', "External Execute Script");
|
||||
export const mlsPythonLanguageTitle = localize('mls.pythonLanguageTitle', "Python");
|
||||
export const mlsRLanguageTitle = localize('mls.rLanguageTitle', "R");
|
||||
export const downloadError = localize('mls.downloadError', "Error while downloading");
|
||||
export function invalidModelIdError(modelUrl: string | undefined): string { return localize('mls.invalidModelIdError', "Invalid model id. model url: {0}", modelUrl || ''); }
|
||||
export function noArtifactError(modelUrl: string | undefined): string { return localize('mls.noArtifactError', "Model doesn't have any artifact. model url: {0}", modelUrl || ''); }
|
||||
export const downloadingProgress = localize('mls.downloadingProgress', "Downloading");
|
||||
export const pythonConfigError = localize('mls.pythonConfigError', "Python executable is not configured");
|
||||
export const rConfigError = localize('mls.rConfigError', "R executable is not configured");
|
||||
export const installingDependencies = localize('mls.installingDependencies', "Installing dependencies ...");
|
||||
export const resourceNotFoundError = localize('mls.resourceNotFound', "Could not find the specified resource");
|
||||
export const latestVersion = localize('mls.latestVersion', "Latest");
|
||||
export const localhost = 'localhost';
|
||||
export function httpGetRequestError(code: number, message: string): string {
|
||||
return localize('mls.httpGetRequestError', "Package info request failed with error: {0} {1}",
|
||||
code,
|
||||
message);
|
||||
}
|
||||
export function getErrorMessage(error: Error): string { return localize('azure.resource.error', "Error: {0}", error?.message || error?.toString()); }
|
||||
export const notSupportedEventArg = localize('notSupportedEventArg', "Not supported event args");
|
||||
export const extLangInstallTabTitle = localize('extLang.installTabTitle', "Installed");
|
||||
export const extLangLanguageCreatedDate = localize('extLang.languageCreatedDate', "Installed");
|
||||
export const extLangLanguagePlatform = localize('extLang.languagePlatform', "Platform");
|
||||
export const deleteTitle = localize('extLang.delete', "Delete");
|
||||
export const extLangInstallButtonText = localize('extLang.installButtonText', "Install");
|
||||
export const extLangCancelButtonText = localize('extLang.CancelButtonText', "Cancel");
|
||||
export const extLangDoneButtonText = localize('extLang.DoneButtonText', "Close");
|
||||
export const extLangOkButtonText = localize('extLang.OkButtonText', "OK");
|
||||
export const extLangSaveButtonText = localize('extLang.SaveButtonText', "Save");
|
||||
export const extLangLanguageName = localize('extLang.languageName', "Name");
|
||||
export const extLangNewLanguageTabTitle = localize('extLang.newLanguageTabTitle', "Add new");
|
||||
export const extLangFileBrowserTabTitle = localize('extLang.fileBrowserTabTitle', "File Browser");
|
||||
export const extLangDialogTitle = localize('extLang.DialogTitle', "Languages");
|
||||
export const extLangTarget = localize('extLang.Target', "Target");
|
||||
export const extLangLocal = localize('extLang.Local', "localhost");
|
||||
export const extLangExtensionFilePath = localize('extLang.extensionFilePath', "Language extension path");
|
||||
export const extLangExtensionFileLocation = localize('extLang.extensionFileLocation', "Language extension location");
|
||||
export const extLangExtensionFileName = localize('extLang.extensionFileName', "Extension file Name");
|
||||
export const extLangEnvVariables = localize('extLang.envVariables', "Environment variables");
|
||||
export const extLangParameters = localize('extLang.parameters', "Parameters");
|
||||
export const extLangSelectedPath = localize('extLang.selectedPath', "Selected Path");
|
||||
export const extLangInstallFailedError = localize('extLang.installFailedError', "Failed to install language");
|
||||
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
|
||||
|
||||
export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model");
|
||||
export const databaseName = localize('databaseName', "Database name");
|
||||
export const tableName = localize('tableName', "Table name");
|
||||
export const modelName = localize('models.name', "Name");
|
||||
export const modelFileName = localize('models.fileName', "File");
|
||||
export const modelDescription = localize('models.description', "Description");
|
||||
export const modelCreated = localize('models.created', "Date created");
|
||||
export const modelDeployed = localize('models.deployed', "Date deployed");
|
||||
export const modelFramework = localize('models.framework', "Framework");
|
||||
export const modelFrameworkVersion = localize('models.frameworkVersion', "Framework version");
|
||||
export const modelVersion = localize('models.version', "Version");
|
||||
export const browseModels = localize('models.browseButton', "...");
|
||||
export const azureAccount = localize('models.azureAccount', "Azure account");
|
||||
export const azureSignIn = localize('models.azureSignIn', "Sign in to Azure");
|
||||
export const columnDatabase = localize('predict.columnDatabase', "Target database");
|
||||
export const columnTable = localize('predict.columnTable', "Target table");
|
||||
export const inputColumns = localize('predict.inputColumns', "Model input mapping");
|
||||
export const outputColumns = localize('predict.outputColumns', "Model output");
|
||||
export const columnName = localize('predict.columnName', "Target columns");
|
||||
export const dataTypeName = localize('predict.dataTypeName', "Type");
|
||||
export const displayName = localize('predict.displayName', "Display name");
|
||||
export const inputName = localize('predict.inputName', "Required model input features");
|
||||
export const outputName = localize('predict.outputName', "Name");
|
||||
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
|
||||
export const azureGroup = localize('models.azureGroup', "Azure resource group");
|
||||
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
|
||||
export const azureModelFilter = localize('models.azureModelFilter', "Filter");
|
||||
export const azureModels = localize('models.azureModels', "Models");
|
||||
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
|
||||
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
|
||||
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
|
||||
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?");
|
||||
export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Where do you want import models to?");
|
||||
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
|
||||
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details");
|
||||
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
|
||||
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
|
||||
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
|
||||
export const predictModel = localize('models.predictModel', "Predict");
|
||||
export const registerModelTitle = localize('models.RegisterWizard', "Import models");
|
||||
export const importModelTitle = localize('models.importModelTitle', "Import models");
|
||||
export const editModelTitle = localize('models.editModelTitle', "Edit model");
|
||||
export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model");
|
||||
export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions");
|
||||
export const makePredictionDesc = localize('models.makePredictionDesc', "Generates a predicted value or scores using a managed model");
|
||||
export const createNotebookTitle = localize('models.createNotebookTitle', "Create notebook");
|
||||
export const createNotebookDesc = localize('models.createNotebookDesc', "Run experiments and create models");
|
||||
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
|
||||
export const modelUpdatedSuccessfully = localize('models.modelUpdatedSuccessfully', "Model updated successfully");
|
||||
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
|
||||
export const localModelSource = localize('models.localModelSource', "File upload");
|
||||
export const localModelPageTitle = localize('models.localModelPageTitle', "Upload model file");
|
||||
export const azureModelSource = localize('models.azureModelSource', "Azure Machine Learning");
|
||||
export const azureModelPageTitle = localize('models.azureModelPageTitle', "Import from Azure Machine Learning");
|
||||
export const importedModelsPageTitle = localize('models.importedModelsPageTitle', "Select imported model");
|
||||
export const registeredModelsSource = localize('models.registeredModelsSource', "Imported models");
|
||||
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
|
||||
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
|
||||
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
|
||||
export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict");
|
||||
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
|
||||
export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table");
|
||||
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
|
||||
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
|
||||
export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); }
|
||||
export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
|
||||
export function invalidImportTableSchemaError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableSchemaError', "Table schema is not supported for model import. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
|
||||
|
||||
export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
|
||||
export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
|
||||
export const dashboardTitle = localize('dashboardTitle', "Machine Learning");
|
||||
export const dashboardDesc = localize('dashboardDesc', "Machine Learning for SQL Databases");
|
||||
export const dashboardLinksTitle = localize('dashboardLinksTitle', "Useful links");
|
||||
export const dashboardVideoLinksTitle = localize('dashboardVideoLinksTitle', "Video tutorials");
|
||||
export const showMoreTitle = localize('showMoreTitle', "Show more");
|
||||
export const showLessTitle = localize('showLessTitle', "Show less");
|
||||
export const learnMoreTitle = localize('learnMoreTitle', "Learn more");
|
||||
export const sqlMlDocTitle = localize('sqlMlDocTitle', "SQL machine learning documentation");
|
||||
export const sqlMlDocDesc = localize('sqlMlDocDesc', "Learn how to use machine learning in SQL Server and SQL on Azure, to run Python and R scripts on relational data.");
|
||||
export const sqlMlsDocTitle = localize('sqlMlsDocTitle', "SQL Server Machine Learning Services (Python and R)");
|
||||
export const sqlMlsDocDesc = localize('sqlMlsDocDesc', "Get started with Machine Learning Services on SQL Server and how to install it on Windows and Linux.");
|
||||
export const sqlMlsAzureDocTitle = localize('sqlMlsAzureDocTitle', "Machine Learning Services in Azure SQL Managed Instance (preview)");
|
||||
export const sqlMlsAzureDocDesc = localize('sqlMlsAzureDocDesc', "Get started with Machine Learning Services in Azure SQL Managed Instances.");
|
||||
export const mlsInstallOdbcDocTitle = localize('mlsInstallObdcDocTitle', "Install the Microsoft ODBC driver for SQL Server");
|
||||
export const mlsInstallOdbcDocDesc = localize('mlsInstallOdbcDocDesc', "This document explains how to install the Microsoft ODBC Driver for SQL Server.");
|
||||
|
||||
// Links
|
||||
//
|
||||
export const mlsDocuments = 'https://docs.microsoft.com/sql/advanced-analytics/?view=sql-server-ver15';
|
||||
export const odbcDriverWindowsDocuments = 'https://docs.microsoft.com/sql/connect/odbc/windows/microsoft-odbc-driver-for-sql-server-on-windows?view=sql-server-ver15';
|
||||
export const odbcDriverLinuxDocuments = 'https://docs.microsoft.com/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-for-sql-server?view=sql-server-ver15';
|
||||
export const mlDocLink = 'https://docs.microsoft.com/sql/machine-learning/';
|
||||
export const mlsDocLink = 'https://docs.microsoft.com/sql/machine-learning/what-is-sql-server-machine-learning';
|
||||
export const mlsAzureDocLink = 'https://docs.microsoft.com/azure/sql-database/sql-database-managed-instance-machine-learning-services-overview';
|
||||
export const installMlsWindowsDocs = 'https://docs.microsoft.com/sql/advanced-analytics/install/sql-machine-learning-services-windows-install?view=sql-server-ver15';
|
||||
|
||||
// CSS Styles
|
||||
//
|
||||
export namespace cssStyles {
|
||||
export const title = { 'font-size': '14px', 'font-weight': '600' };
|
||||
export const tableHeader = { 'text-align': 'left', 'font-weight': 'bold', 'text-transform': 'uppercase', 'font-size': '10px', 'user-select': 'text', 'border': 'none' };
|
||||
export const tableRow = { 'border-top': 'solid 1px #ccc', 'border-bottom': 'solid 1px #ccc', 'border-left': 'none', 'border-right': 'none' };
|
||||
export const hyperlink = { 'user-select': 'text', 'color': '#0078d4', 'text-decoration': 'underline', 'cursor': 'pointer' };
|
||||
export const text = { 'margin-block-start': '0px', 'margin-block-end': '0px' };
|
||||
export const overflowEllipsisText = { ...text, 'overflow': 'hidden', 'text-overflow': 'ellipsis' };
|
||||
export const nonSelectableText = { ...cssStyles.text, 'user-select': 'none' };
|
||||
export const tabHeaderText = { 'margin-block-start': '2px', 'margin-block-end': '0px', 'user-select': 'none' };
|
||||
export const selectedResourceHeaderTab = { 'font-weight': 'bold', 'color': '' };
|
||||
export const unselectedResourceHeaderTab = { 'font-weight': '', 'color': '#0078d4' };
|
||||
export const selectedTabDiv = { 'border-bottom': '2px solid #000' };
|
||||
export const unselectedTabDiv = { 'border-bottom': '1px solid #ccc' };
|
||||
export const lastUpdatedText = { ...text, 'color': '#595959' };
|
||||
export const errorText = { ...text, 'color': 'red' };
|
||||
}
|
||||
45
extensions/machine-learning/src/common/eventEmitter.ts
Normal file
45
extensions/machine-learning/src/common/eventEmitter.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
export class EventEmitterCollection extends vscode.Disposable {
|
||||
private _events: Map<string, vscode.EventEmitter<any>[]> = new Map<string, vscode.EventEmitter<any>[]>();
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor() {
|
||||
super(() => this.dispose());
|
||||
|
||||
}
|
||||
|
||||
public on(evt: string, listener: (e: any) => any, thisArgs?: any) {
|
||||
if (!this._events.has(evt)) {
|
||||
this._events.set(evt, []);
|
||||
}
|
||||
let eventEmitter = new vscode.EventEmitter<any>();
|
||||
eventEmitter.event(listener, thisArgs);
|
||||
this._events.get(evt)?.push(eventEmitter);
|
||||
return this;
|
||||
}
|
||||
|
||||
public fire(evt: string, arg?: any) {
|
||||
if (!this._events.has(evt)) {
|
||||
this._events.set(evt, []);
|
||||
}
|
||||
this._events.get(evt)?.forEach(eventEmitter => {
|
||||
eventEmitter.fire(arg);
|
||||
});
|
||||
}
|
||||
|
||||
public dispose(): any {
|
||||
this._events.forEach(events => {
|
||||
events.forEach(event => {
|
||||
event.dispose();
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
80
extensions/machine-learning/src/common/httpClient.ts
Normal file
80
extensions/machine-learning/src/common/httpClient.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as fs from 'fs';
|
||||
import * as request from 'request';
|
||||
import * as constants from './constants';
|
||||
|
||||
const DownloadTimeout = 20000;
|
||||
const GetTimeout = 10000;
|
||||
export class HttpClient {
|
||||
|
||||
public async fetch(url: string): Promise<any> {
|
||||
return new Promise<any>((resolve, reject) => {
|
||||
request.get(url, { timeout: GetTimeout }, (error, response, body) => {
|
||||
if (error) {
|
||||
return reject(error);
|
||||
}
|
||||
|
||||
if (response.statusCode === 404) {
|
||||
return reject(constants.resourceNotFoundError);
|
||||
}
|
||||
|
||||
if (response.statusCode !== 200) {
|
||||
return reject(
|
||||
constants.httpGetRequestError(
|
||||
response.statusCode,
|
||||
response.statusMessage));
|
||||
}
|
||||
|
||||
resolve(body);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
public download(downloadUrl: string, targetPath: string, outputChannel: vscode.OutputChannel): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
|
||||
let totalMegaBytes: number | undefined = undefined;
|
||||
let receivedBytes = 0;
|
||||
let printThreshold = 0.1;
|
||||
let downloadRequest = request.get(downloadUrl, { timeout: DownloadTimeout })
|
||||
.on('error', downloadError => {
|
||||
outputChannel.appendLine(constants.downloadError);
|
||||
reject(downloadError);
|
||||
})
|
||||
.on('response', (response) => {
|
||||
if (response.statusCode !== 200) {
|
||||
outputChannel.appendLine(constants.downloadError);
|
||||
return reject(response.statusMessage);
|
||||
}
|
||||
let contentLength = response.headers['content-length'];
|
||||
let totalBytes = parseInt(contentLength || '0');
|
||||
totalMegaBytes = totalBytes / (1024 * 1024);
|
||||
outputChannel.appendLine(`'Downloading' (0 / ${totalMegaBytes.toFixed(2)} MB)`);
|
||||
})
|
||||
.on('data', (data) => {
|
||||
receivedBytes += data.length;
|
||||
if (totalMegaBytes) {
|
||||
let receivedMegaBytes = receivedBytes / (1024 * 1024);
|
||||
let percentage = receivedMegaBytes / totalMegaBytes;
|
||||
if (percentage >= printThreshold) {
|
||||
outputChannel.appendLine(`${constants.downloadingProgress} (${receivedMegaBytes.toFixed(2)} / ${totalMegaBytes.toFixed(2)} MB)`);
|
||||
printThreshold += 0.1;
|
||||
}
|
||||
}
|
||||
});
|
||||
downloadRequest.pipe(fs.createWriteStream(targetPath))
|
||||
.on('close', async () => {
|
||||
resolve();
|
||||
})
|
||||
.on('error', (downloadError) => {
|
||||
reject(downloadError);
|
||||
downloadRequest.abort();
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
91
extensions/machine-learning/src/common/processService.ts
Normal file
91
extensions/machine-learning/src/common/processService.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as childProcess from 'child_process';
|
||||
|
||||
const ExecScriptsTimeoutInSeconds = 600000;
|
||||
export class ProcessService {
|
||||
|
||||
public timeout = ExecScriptsTimeoutInSeconds;
|
||||
|
||||
public async execScripts(exeFilePath: string, scripts: string[], args?: string[], outputChannel?: vscode.OutputChannel): Promise<string> {
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
|
||||
const scriptExecution = childProcess.spawn(exeFilePath, args);
|
||||
let timer: NodeJS.Timeout;
|
||||
let output: string = '';
|
||||
scripts.forEach(script => {
|
||||
scriptExecution.stdin.write(`${script}\n`);
|
||||
});
|
||||
scriptExecution.stdin.end();
|
||||
|
||||
// Add listeners to print stdout and stderr if an output channel was provided
|
||||
|
||||
scriptExecution.stdout.on('data', data => {
|
||||
if (outputChannel) {
|
||||
this.outputDataChunk(data, outputChannel, ' stdout: ');
|
||||
}
|
||||
output = output + data.toString();
|
||||
});
|
||||
scriptExecution.stderr.on('data', data => {
|
||||
if (outputChannel) {
|
||||
this.outputDataChunk(data, outputChannel, ' stderr: ');
|
||||
}
|
||||
output = output + data.toString();
|
||||
});
|
||||
|
||||
scriptExecution.on('exit', (code) => {
|
||||
if (timer) {
|
||||
clearTimeout(timer);
|
||||
}
|
||||
if (code === 0) {
|
||||
resolve(output);
|
||||
} else {
|
||||
reject(`Process exited with code: ${code}. output: ${output}`);
|
||||
}
|
||||
|
||||
});
|
||||
timer = setTimeout(() => {
|
||||
try {
|
||||
scriptExecution.kill();
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
}
|
||||
}, this.timeout);
|
||||
});
|
||||
}
|
||||
|
||||
public async executeBufferedCommand(cmd: string, outputChannel?: vscode.OutputChannel): Promise<string> {
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
if (outputChannel) {
|
||||
outputChannel.appendLine(` > ${cmd}`);
|
||||
}
|
||||
|
||||
let child = childProcess.exec(cmd, {
|
||||
timeout: this.timeout
|
||||
}, (err, stdout) => {
|
||||
if (err) {
|
||||
reject(err);
|
||||
} else {
|
||||
resolve(stdout);
|
||||
}
|
||||
});
|
||||
|
||||
// Add listeners to print stdout and stderr if an output channel was provided
|
||||
if (outputChannel) {
|
||||
child.stdout.on('data', data => { this.outputDataChunk(data, outputChannel, ' stdout: '); });
|
||||
child.stderr.on('data', data => { this.outputDataChunk(data, outputChannel, ' stderr: '); });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private outputDataChunk(data: string | Buffer, outputChannel: vscode.OutputChannel, header: string): void {
|
||||
data.toString().split(/\r?\n/)
|
||||
.forEach(line => {
|
||||
outputChannel.appendLine(header + line);
|
||||
});
|
||||
}
|
||||
}
|
||||
214
extensions/machine-learning/src/common/queryRunner.ts
Normal file
214
extensions/machine-learning/src/common/queryRunner.ts
Normal file
@@ -0,0 +1,214 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
import { ApiWrapper } from './apiWrapper';
|
||||
import * as constants from '../common/constants';
|
||||
import * as utils from '../common/utils';
|
||||
|
||||
const maxNumberOfRetries = 2;
|
||||
|
||||
const listPythonPackagesQuery = `
|
||||
Declare @tablevar table(name NVARCHAR(MAX), version NVARCHAR(MAX))
|
||||
insert into @tablevar(name, version)
|
||||
EXEC sp_execute_external_script
|
||||
@language=N'Python',
|
||||
@script=N'import pkg_resources
|
||||
import pandas
|
||||
OutputDataSet = pandas.DataFrame([(d.project_name, d.version) for d in pkg_resources.working_set])'
|
||||
select e.name, version from sys.external_libraries e join @tablevar t on e.name = t.name
|
||||
where [language] = 'PYTHON'
|
||||
`;
|
||||
|
||||
const listRPackagesQuery = `
|
||||
Declare @tablevar table(name NVARCHAR(MAX), version NVARCHAR(MAX))
|
||||
insert into @tablevar(name, version)
|
||||
EXEC sp_execute_external_script
|
||||
@language=N'R',
|
||||
@script=N'
|
||||
OutputDataSet <- as.data.frame(installed.packages()[,c(1,3)])'
|
||||
|
||||
select e.name, version from sys.external_libraries e join @tablevar t on e.name = t.name
|
||||
where [language] = 'R'
|
||||
`;
|
||||
|
||||
const checkMlInstalledQuery = `
|
||||
Declare @tablevar table(name NVARCHAR(MAX), min INT, max INT, config_value bit, run_value bit)
|
||||
insert into @tablevar(name, min, max, config_value, run_value) exec sp_configure
|
||||
|
||||
Declare @external_script_enabled bit
|
||||
SELECT @external_script_enabled=config_value FROM @tablevar WHERE name = 'external scripts enabled'
|
||||
SELECT @external_script_enabled`;
|
||||
|
||||
const checkLanguageInstalledQuery = `
|
||||
|
||||
SELECT is_installed
|
||||
FROM sys.dm_db_external_language_stats s, sys.external_languages l
|
||||
WHERE s.external_language_id = l.external_language_id AND language = '#LANGUAGE#'`;
|
||||
|
||||
const modifyExternalScriptConfigQuery = `
|
||||
|
||||
EXEC sp_configure 'external scripts enabled', #CONFIG_VALUE#;
|
||||
RECONFIGURE WITH OVERRIDE;
|
||||
|
||||
Declare @tablevar table(name NVARCHAR(MAX), min INT, max INT, config_value bit, run_value bit)
|
||||
insert into @tablevar(name, min, max, config_value, run_value) exec sp_configure
|
||||
|
||||
Declare @external_script_enabled bit
|
||||
SELECT @external_script_enabled=config_value FROM @tablevar WHERE name = 'external scripts enabled'
|
||||
SELECT @external_script_enabled`;
|
||||
|
||||
/**
|
||||
* SQL Query runner
|
||||
*/
|
||||
export class QueryRunner {
|
||||
|
||||
constructor(private _apiWrapper: ApiWrapper) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns python packages installed in SQL server instance
|
||||
* @param connection SQL Connection
|
||||
*/
|
||||
public async getPythonPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
return this.getPackages(connection, databaseName, listPythonPackagesQuery);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns python packages installed in SQL server instance
|
||||
* @param connection SQL Connection
|
||||
*/
|
||||
public async getRPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
return this.getPackages(connection, databaseName, listRPackagesQuery);
|
||||
}
|
||||
|
||||
private async getPackages(connection: azdata.connection.ConnectionProfile, databaseName: string, script: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [];
|
||||
let result: azdata.SimpleExecuteResult | undefined = undefined;
|
||||
|
||||
for (let index = 0; index < maxNumberOfRetries; index++) {
|
||||
result = await this.runQuery(connection, utils.getScriptWithDBChange(connection.databaseName, databaseName, script));
|
||||
if (result && result.rowCount > 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (result && result.rows.length > 0) {
|
||||
packages = result.rows.map(row => {
|
||||
return {
|
||||
name: row[0].displayValue,
|
||||
version: row[1].displayValue
|
||||
};
|
||||
});
|
||||
}
|
||||
return packages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates External Script Config in a SQL server instance
|
||||
* @param connection SQL Connection
|
||||
* @param enable if true the config will be enabled otherwise it will be disabled
|
||||
*/
|
||||
public async updateExternalScriptConfig(connection: azdata.connection.ConnectionProfile, enable: boolean): Promise<void> {
|
||||
let query = modifyExternalScriptConfigQuery;
|
||||
let configValue = enable ? '1' : '0';
|
||||
query = query.replace('#CONFIG_VALUE#', configValue);
|
||||
|
||||
await this.runQuery(connection, query);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if python installed in the give SQL server instance
|
||||
*/
|
||||
public async isPythonInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
|
||||
return this.isLanguageInstalled(connection, constants.pythonLanguageName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if R installed in the give SQL server instance
|
||||
*/
|
||||
public async isRInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
|
||||
return this.isLanguageInstalled(connection, constants.rLanguageName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if language installed in the give SQL server instance
|
||||
*/
|
||||
private async isLanguageInstalled(connection: azdata.connection.ConnectionProfile, language: string): Promise<boolean> {
|
||||
let result = await this.runQuery(connection, checkLanguageInstalledQuery.replace('#LANGUAGE#', language));
|
||||
let isInstalled = false;
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
isInstalled = result.rows[0][0].displayValue === '1';
|
||||
}
|
||||
return isInstalled;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if mls is installed in the give SQL server instance
|
||||
*/
|
||||
public async isMachineLearningServiceEnabled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
|
||||
let result = await this.runQuery(connection, checkMlInstalledQuery);
|
||||
let isEnabled = false;
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
isEnabled = result.rows[0][0].displayValue === '1';
|
||||
}
|
||||
return isEnabled;
|
||||
}
|
||||
|
||||
public async runQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
let result: azdata.SimpleExecuteResult | undefined = undefined;
|
||||
try {
|
||||
if (connection) {
|
||||
let connectionUri = await this._apiWrapper.getUriForConnection(connection.connectionId);
|
||||
let queryProvider = this._apiWrapper.getProvider<azdata.QueryProvider>(connection.providerId, azdata.DataProviderType.QueryProvider);
|
||||
if (queryProvider) {
|
||||
result = await queryProvider.runQueryAndReturn(connectionUri, query);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes the query but doesn't fail it is fails
|
||||
* @param connection SQL connection
|
||||
* @param query query to run
|
||||
*/
|
||||
public async safeRunQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
try {
|
||||
return await this.runQuery(connection, query);
|
||||
} catch (error) {
|
||||
//console.log(error);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes the query but doesn't fail it is fails
|
||||
* @param connection SQL connection
|
||||
* @param query query to run
|
||||
*/
|
||||
public async runWithDatabaseChange(connection: azdata.connection.ConnectionProfile, query: string, queryDb: string): Promise<azdata.SimpleExecuteResult | undefined> {
|
||||
if (connection) {
|
||||
try {
|
||||
return await this.runQuery(connection, `
|
||||
USE [${utils.doubleEscapeSingleBrackets(queryDb)}]
|
||||
${query}`);
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
}
|
||||
finally {
|
||||
this.safeRunQuery(connection, `USE [${utils.doubleEscapeSingleBrackets(connection.databaseName || 'master')}]`);
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
261
extensions/machine-learning/src/common/utils.ts
Normal file
261
extensions/machine-learning/src/common/utils.ts
Normal file
@@ -0,0 +1,261 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as fs from 'fs';
|
||||
import * as constants from './constants';
|
||||
import { promisify } from 'util';
|
||||
import { ApiWrapper } from './apiWrapper';
|
||||
|
||||
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
|
||||
let tempFilePath: string = '';
|
||||
try {
|
||||
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, content);
|
||||
let result = await command(tempFilePath);
|
||||
return result;
|
||||
}
|
||||
finally {
|
||||
await deleteFile(tempFilePath);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes a file
|
||||
* @param filePath file path
|
||||
*/
|
||||
export async function deleteFile(filePath: string) {
|
||||
if (filePath) {
|
||||
await fs.promises.unlink(filePath);
|
||||
}
|
||||
}
|
||||
|
||||
export async function readFileInHex(filePath: string): Promise<string> {
|
||||
let buffer = await fs.promises.readFile(filePath);
|
||||
return `0X${buffer.toString('hex')}`;
|
||||
}
|
||||
|
||||
export async function exists(path: string): Promise<boolean> {
|
||||
return promisify(fs.exists)(path);
|
||||
}
|
||||
|
||||
export async function createFolder(dirPath: string): Promise<void> {
|
||||
let folderExists = await exists(dirPath);
|
||||
if (!folderExists) {
|
||||
await fs.promises.mkdir(dirPath);
|
||||
}
|
||||
}
|
||||
|
||||
export function getPythonInstallationLocation(rootFolder: string) {
|
||||
return path.join(rootFolder, 'python');
|
||||
}
|
||||
|
||||
export function getPythonExePath(rootFolder: string): string {
|
||||
return path.join(
|
||||
getPythonInstallationLocation(rootFolder),
|
||||
constants.pythonBundleVersion,
|
||||
process.platform === constants.winPlatform ? 'python.exe' : 'bin/python3');
|
||||
}
|
||||
|
||||
export function getPackageFilePath(rootFolder: string, packageName: string): string {
|
||||
return path.join(
|
||||
rootFolder,
|
||||
constants.rLPackagedFolderName,
|
||||
packageName);
|
||||
}
|
||||
|
||||
export function getRPackagesFolderPath(rootFolder: string): string {
|
||||
return path.join(
|
||||
rootFolder,
|
||||
constants.rLPackagedFolderName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compares two version strings to see which is greater.
|
||||
* @param first First version string to compare.
|
||||
* @param second Second version string to compare.
|
||||
* @returns 1 if the first version is greater, -1 if it's less, and 0 otherwise.
|
||||
*/
|
||||
export function comparePackageVersions(first: string, second: string): number {
|
||||
let firstVersion = first.split('.').map(numStr => Number.parseInt(numStr));
|
||||
let secondVersion = second.split('.').map(numStr => Number.parseInt(numStr));
|
||||
|
||||
// If versions have different lengths, then append zeroes to the shorter one
|
||||
if (firstVersion.length > secondVersion.length) {
|
||||
let diff = firstVersion.length - secondVersion.length;
|
||||
secondVersion = secondVersion.concat(new Array(diff).fill(0));
|
||||
} else if (secondVersion.length > firstVersion.length) {
|
||||
let diff = secondVersion.length - firstVersion.length;
|
||||
firstVersion = firstVersion.concat(new Array(diff).fill(0));
|
||||
}
|
||||
|
||||
for (let i = 0; i < firstVersion.length; ++i) {
|
||||
if (firstVersion[i] > secondVersion[i]) {
|
||||
return 1;
|
||||
} else if (firstVersion[i] < secondVersion[i]) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
export function sortPackageVersions(versions: string[], ascending: boolean = true) {
|
||||
return versions.sort((first, second) => {
|
||||
let compareResult = comparePackageVersions(first, second);
|
||||
if (ascending) {
|
||||
return compareResult;
|
||||
} else {
|
||||
return compareResult * -1;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
export function isWindows(): boolean {
|
||||
return process.platform === 'win32';
|
||||
}
|
||||
|
||||
/**
|
||||
* Escapes all single-quotes (') by prefixing them with another single quote ('')
|
||||
* ' => ''
|
||||
* @param value The string to escape
|
||||
*/
|
||||
export function doubleEscapeSingleQuotes(value: string | undefined): string {
|
||||
return value ? value.replace(/'/g, '\'\'') : '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Escapes all single-bracket ([]) by replacing them with another bracket quote ([[]])
|
||||
* ' => ''
|
||||
* @param value The string to escape
|
||||
*/
|
||||
export function doubleEscapeSingleBrackets(value: string | undefined): string {
|
||||
return value ? value.replace(/\[/g, '[[').replace(/\]/g, ']]') : '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
*/
|
||||
export async function executeTasks<T>(apiWrapper: ApiWrapper, taskName: string, dependencies: PromiseLike<T>[], parallel: boolean): Promise<T[]> {
|
||||
return new Promise<T[]>((resolve, reject) => {
|
||||
let msgTaskName = taskName;
|
||||
apiWrapper.startBackgroundOperation({
|
||||
displayName: msgTaskName,
|
||||
description: msgTaskName,
|
||||
isCancelable: false,
|
||||
operation: async op => {
|
||||
try {
|
||||
let result: T[] = [];
|
||||
// Install required packages
|
||||
//
|
||||
if (parallel) {
|
||||
result = await Promise.all(dependencies);
|
||||
} else {
|
||||
for (let index = 0; index < dependencies.length; index++) {
|
||||
result.push(await dependencies[index]);
|
||||
}
|
||||
}
|
||||
op.updateStatus(azdata.TaskStatus.Succeeded);
|
||||
resolve(result);
|
||||
} catch (error) {
|
||||
let errorMsg = constants.taskFailedError(taskName, error ? error.message : '');
|
||||
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
|
||||
reject(errorMsg);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Promise<boolean> {
|
||||
let choices: { [id: string]: boolean } = {};
|
||||
choices[constants.msgYes] = true;
|
||||
choices[constants.msgNo] = false;
|
||||
|
||||
let options = {
|
||||
placeHolder: message
|
||||
};
|
||||
|
||||
let result = await apiWrapper.showQuickPick(Object.keys(choices).map(c => {
|
||||
return {
|
||||
label: c
|
||||
};
|
||||
}), options);
|
||||
if (result === undefined) {
|
||||
throw Error('invalid selection');
|
||||
}
|
||||
|
||||
return choices[result.label] || false;
|
||||
}
|
||||
|
||||
export function makeLinuxPath(filePath: string): string {
|
||||
const parts = filePath.split('\\');
|
||||
return parts.join('/');
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param currentDb Wraps the given script with database switch scripts
|
||||
* @param databaseName
|
||||
* @param script
|
||||
*/
|
||||
export function getScriptWithDBChange(currentDb: string, databaseName: string, script: string): string {
|
||||
if (!currentDb) {
|
||||
currentDb = 'master';
|
||||
}
|
||||
let escapedDbName = doubleEscapeSingleBrackets(databaseName);
|
||||
let escapedCurrentDbName = doubleEscapeSingleBrackets(currentDb);
|
||||
return `
|
||||
USE [${escapedDbName}]
|
||||
${script}
|
||||
USE [${escapedCurrentDbName}]
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns full name of model registration table
|
||||
* @param config config
|
||||
*/
|
||||
export function getRegisteredModelsThreePartsName(db: string, table: string, schema: string) {
|
||||
const dbName = doubleEscapeSingleBrackets(db);
|
||||
const schemaName = doubleEscapeSingleBrackets(schema);
|
||||
const tableName = doubleEscapeSingleBrackets(table);
|
||||
return `[${dbName}].[${schemaName}].[${tableName}]`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns full name of model registration table
|
||||
* @param config config object
|
||||
*/
|
||||
export function getRegisteredModelsTwoPartsName(table: string, schema: string) {
|
||||
const schemaName = doubleEscapeSingleBrackets(schema);
|
||||
const tableName = doubleEscapeSingleBrackets(table);
|
||||
return `[${schemaName}].[${tableName}]`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Write a file using a hex string
|
||||
* @param content file content
|
||||
*/
|
||||
export async function writeFileFromHex(content: string): Promise<string> {
|
||||
content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content;
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex'));
|
||||
return tempFilePath;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param filePath Returns file name
|
||||
*/
|
||||
export function getFileName(filePath: string) {
|
||||
if (filePath) {
|
||||
return filePath.replace(/^.*[\\\/]/, '');
|
||||
} else {
|
||||
return '';
|
||||
}
|
||||
}
|
||||
138
extensions/machine-learning/src/configurations/config.ts
Normal file
138
extensions/machine-learning/src/configurations/config.ts
Normal file
@@ -0,0 +1,138 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as constants from '../common/constants';
|
||||
import { promises as fs } from 'fs';
|
||||
import * as path from 'path';
|
||||
import { PackageConfigModel } from './packageConfigModel';
|
||||
|
||||
const configFileName = 'config.json';
|
||||
const defaultPythonExecutable = 'python';
|
||||
const defaultRExecutable = 'r';
|
||||
|
||||
|
||||
/**
|
||||
* Extension Configuration from app settings
|
||||
*/
|
||||
export class Config {
|
||||
|
||||
private _configValues: any;
|
||||
|
||||
constructor(private _root: string, private _apiWrapper: ApiWrapper) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the config values
|
||||
*/
|
||||
public async load(): Promise<void> {
|
||||
const rawConfig = await fs.readFile(path.join(this._root, configFileName));
|
||||
this._configValues = JSON.parse(rawConfig.toString());
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the config value of required python packages
|
||||
*/
|
||||
public get requiredSqlPythonPackages(): PackageConfigModel[] {
|
||||
return this._configValues.sqlPackageManagement.requiredPythonPackages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the config value of required r packages
|
||||
*/
|
||||
public get requiredSqlRPackages(): PackageConfigModel[] {
|
||||
return this._configValues.sqlPackageManagement.requiredRPackages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns r packages repository
|
||||
*/
|
||||
public get rPackagesRepository(): string {
|
||||
return this._configValues.sqlPackageManagement.rPackagesRepository;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns python path from user settings
|
||||
*/
|
||||
public get pythonExecutable(): string {
|
||||
return this.config.get(constants.pythonPathConfigKey) || defaultPythonExecutable;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if python package management is enabled
|
||||
*/
|
||||
public get pythonEnabled(): boolean {
|
||||
return this.config.get(constants.pythonEnabledConfigKey) || false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if r package management is enabled
|
||||
*/
|
||||
public get rEnabled(): boolean {
|
||||
return this.config.get(constants.rEnabledConfigKey) || false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns registered models table name
|
||||
*/
|
||||
public get registeredModelTableName(): string {
|
||||
return this._configValues.modelManagement.registeredModelsTableName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns registered models table schema name
|
||||
*/
|
||||
public get registeredModelTableSchemaName(): string {
|
||||
return this._configValues.modelManagement.registeredModelsTableSchemaName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns registered models table name
|
||||
*/
|
||||
public get registeredModelDatabaseName(): string {
|
||||
return this._configValues.modelManagement.registeredModelsDatabaseName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns Azure ML API
|
||||
*/
|
||||
public get amlModelManagementUrl(): string {
|
||||
return this._configValues.modelManagement.amlModelManagementUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns Azure ML API
|
||||
*/
|
||||
public get amlExperienceUrl(): string {
|
||||
return this._configValues.modelManagement.amlExperienceUrl;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns Azure ML API Version
|
||||
*/
|
||||
public get amlApiVersion(): string {
|
||||
return this._configValues.modelManagement.amlApiVersion;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns model management python packages
|
||||
*/
|
||||
public get modelsRequiredPythonPackages(): PackageConfigModel[] {
|
||||
return this._configValues.modelManagement.requiredPythonPackages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns r path from user settings
|
||||
*/
|
||||
public get rExecutable(): string {
|
||||
return this.config.get(constants.rPathConfigKey) || defaultRExecutable;
|
||||
}
|
||||
|
||||
private get config(): vscode.WorkspaceConfiguration {
|
||||
return this._apiWrapper.getConfiguration(constants.mlsConfigKey);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
/**
|
||||
* The model for package config value
|
||||
*/
|
||||
export interface PackageConfigModel {
|
||||
|
||||
/**
|
||||
* Package name
|
||||
*/
|
||||
name: string;
|
||||
|
||||
/**
|
||||
* Package version
|
||||
*/
|
||||
version?: string;
|
||||
|
||||
/**
|
||||
* Package repository
|
||||
*/
|
||||
repository?: string;
|
||||
|
||||
/**
|
||||
* Package download url
|
||||
*/
|
||||
downloadUrl?: string;
|
||||
|
||||
/**
|
||||
* Package file name if package has download url
|
||||
*/
|
||||
fileName?: string;
|
||||
}
|
||||
165
extensions/machine-learning/src/controllers/mainController.ts
Normal file
165
extensions/machine-learning/src/controllers/mainController.ts
Normal file
@@ -0,0 +1,165 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { ProcessService } from '../common/processService';
|
||||
import { Config } from '../configurations/config';
|
||||
import { PackageManagementService } from '../packageManagement/packageManagementService';
|
||||
import { HttpClient } from '../common/httpClient';
|
||||
import { ModelManagementController } from '../views/models/modelManagementController';
|
||||
import { DeployedModelService } from '../modelManagement/deployedModelService';
|
||||
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
|
||||
import { ModelPythonClient } from '../modelManagement/modelPythonClient';
|
||||
import { PredictService } from '../prediction/predictService';
|
||||
import { DashboardWidget } from '../views/widgets/dashboardWidget';
|
||||
import { ModelConfigRecent } from '../modelManagement/modelConfigRecent';
|
||||
|
||||
/**
|
||||
* The main controller class that initializes the extension
|
||||
*/
|
||||
export default class MainController implements vscode.Disposable {
|
||||
private _outputChannel: vscode.OutputChannel;
|
||||
private _rootPath = this._context.extensionPath;
|
||||
private _config: Config;
|
||||
|
||||
public constructor(
|
||||
private _context: vscode.ExtensionContext,
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _queryRunner: QueryRunner,
|
||||
private _processService: ProcessService,
|
||||
private _packageManager?: PackageManager,
|
||||
private _packageManagementService?: PackageManagementService,
|
||||
private _httpClient?: HttpClient
|
||||
) {
|
||||
this._outputChannel = this._apiWrapper.createOutputChannel(constants.extensionOutputChannel);
|
||||
this._rootPath = this._context.extensionPath;
|
||||
this._config = new Config(this._rootPath, this._apiWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deactivates the extension
|
||||
*/
|
||||
public deactivate(): void {
|
||||
}
|
||||
|
||||
/**
|
||||
* Activates the extension
|
||||
*/
|
||||
public async activate(): Promise<boolean> {
|
||||
await this.initialize();
|
||||
return Promise.resolve(true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an instance of Server Installation from notebook extension
|
||||
*/
|
||||
private async getNotebookExtensionApis(): Promise<nbExtensionApis.IExtensionApi> {
|
||||
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
|
||||
if (nbExtension) {
|
||||
await nbExtension.activate();
|
||||
return (nbExtension.exports as nbExtensionApis.IExtensionApi);
|
||||
} else {
|
||||
throw new Error(constants.notebookExtensionNotLoaded);
|
||||
}
|
||||
}
|
||||
|
||||
private async initialize(): Promise<void> {
|
||||
|
||||
this._outputChannel.show(true);
|
||||
let nbApis = await this.getNotebookExtensionApis();
|
||||
await this._config.load();
|
||||
|
||||
let packageManager = this.getPackageManager(nbApis);
|
||||
this._apiWrapper.registerCommand(constants.mlManagePackagesCommand, (async () => {
|
||||
await packageManager.managePackages();
|
||||
}));
|
||||
|
||||
// External Languages
|
||||
//
|
||||
let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
|
||||
let modelRecentService = new ModelConfigRecent(this._context.globalState);
|
||||
|
||||
// Model Management
|
||||
//
|
||||
let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter, modelRecentService);
|
||||
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
|
||||
let predictService = new PredictService(this._apiWrapper, this._queryRunner);
|
||||
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
|
||||
azureModelsService, registeredModelService, predictService);
|
||||
|
||||
let dashboardWidget = new DashboardWidget(this._apiWrapper, this._rootPath);
|
||||
dashboardWidget.register();
|
||||
|
||||
this._apiWrapper.registerCommand(constants.mlManageModelsCommand, (async () => {
|
||||
await modelManagementController.manageRegisteredModels();
|
||||
}));
|
||||
this._apiWrapper.registerCommand(constants.mlImportModelCommand, (async () => {
|
||||
await modelManagementController.registerModel(undefined);
|
||||
}));
|
||||
this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => {
|
||||
await modelManagementController.predictModel();
|
||||
}));
|
||||
this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => {
|
||||
await packageManager.installDependencies();
|
||||
}));
|
||||
this._apiWrapper.registerTaskHandler(constants.mlManagePackagesCommand, async () => {
|
||||
await packageManager.managePackages();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the package manager instance
|
||||
*/
|
||||
public getPackageManager(nbApis: nbExtensionApis.IExtensionApi): PackageManager {
|
||||
if (!this._packageManager) {
|
||||
this._packageManager = new PackageManager(this._outputChannel, this._rootPath, this._apiWrapper, this.packageManagementService, this._processService, this._config, this.httpClient);
|
||||
this._packageManager.init();
|
||||
this._packageManager.packageManageProviders.forEach(provider => {
|
||||
nbApis.registerPackageManager(provider.providerId, provider);
|
||||
});
|
||||
}
|
||||
return this._packageManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the server config manager instance
|
||||
*/
|
||||
public get packageManagementService(): PackageManagementService {
|
||||
if (!this._packageManagementService) {
|
||||
this._packageManagementService = new PackageManagementService(this._apiWrapper, this._queryRunner);
|
||||
}
|
||||
return this._packageManagementService;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the server config manager instance
|
||||
*/
|
||||
public get httpClient(): HttpClient {
|
||||
if (!this._httpClient) {
|
||||
this._httpClient = new HttpClient();
|
||||
}
|
||||
return this._httpClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* Config instance
|
||||
*/
|
||||
public get config(): Config {
|
||||
return this._config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Disposes the extension
|
||||
*/
|
||||
public dispose(): void {
|
||||
this.deactivate();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as mssql from '../../../mssql';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
|
||||
/**
|
||||
* Manage package dialog model
|
||||
*/
|
||||
export class LanguageService {
|
||||
|
||||
public connection: azdata.connection.ConnectionProfile | undefined;
|
||||
public connectionUrl: string = '';
|
||||
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _languageExtensionService: mssql.ILanguageExtensionService) {
|
||||
}
|
||||
|
||||
public async load() {
|
||||
this.connection = await this.getCurrentConnection();
|
||||
this.connectionUrl = await this.getCurrentConnectionUrl();
|
||||
}
|
||||
|
||||
public async getLanguageList(): Promise<mssql.ExternalLanguage[]> {
|
||||
if (this.connectionUrl) {
|
||||
return await this._languageExtensionService.listLanguages(this.connectionUrl);
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
public async deleteLanguage(languageName: string): Promise<void> {
|
||||
if (this.connectionUrl) {
|
||||
await this._languageExtensionService.deleteLanguage(this.connectionUrl, languageName);
|
||||
}
|
||||
}
|
||||
|
||||
public async updateLanguage(language: mssql.ExternalLanguage): Promise<void> {
|
||||
if (this.connectionUrl) {
|
||||
await this._languageExtensionService.updateLanguage(this.connectionUrl, language);
|
||||
}
|
||||
}
|
||||
|
||||
private async getCurrentConnectionUrl(): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
return await this._apiWrapper.getUriForConnection(connection.connectionId);
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
}
|
||||
33
extensions/machine-learning/src/main.ts
Normal file
33
extensions/machine-learning/src/main.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import MainController from './controllers/mainController';
|
||||
import { ApiWrapper } from './common/apiWrapper';
|
||||
import { QueryRunner } from './common/queryRunner';
|
||||
import { ProcessService } from './common/processService';
|
||||
|
||||
let controllers: MainController[] = [];
|
||||
|
||||
export async function activate(context: vscode.ExtensionContext): Promise<void> {
|
||||
|
||||
let apiWrapper = new ApiWrapper();
|
||||
let queryRunner = new QueryRunner(apiWrapper);
|
||||
let processService = new ProcessService();
|
||||
|
||||
// Start the main controller
|
||||
//
|
||||
let mainController = new MainController(context, apiWrapper, queryRunner, processService);
|
||||
controllers.push(mainController);
|
||||
context.subscriptions.push(mainController);
|
||||
|
||||
await mainController.activate();
|
||||
}
|
||||
|
||||
export function deactivate(): void {
|
||||
for (let controller of controllers) {
|
||||
controller.deactivate();
|
||||
}
|
||||
}
|
||||
62
extensions/machine-learning/src/modelManagement/artifacts.ts
Normal file
62
extensions/machine-learning/src/modelManagement/artifacts.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as msRest from '@azure/ms-rest-js';
|
||||
import * as Models from './interfaces';
|
||||
import * as Mappers from './mappers';
|
||||
import * as Parameters from './parameters';
|
||||
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
|
||||
|
||||
export class Artifacts {
|
||||
private readonly client: AzureMachineLearningWorkspacesContext;
|
||||
|
||||
constructor(client: AzureMachineLearningWorkspacesContext) {
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options?: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams): Promise<Models.GetArtifactContentInformation2Response>;
|
||||
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, callback: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): void;
|
||||
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams, callback: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): void;
|
||||
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options?: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams | msRest.ServiceCallback<Models.ArtifactContentInformationDto>, callback?: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): Promise<Models.GetArtifactContentInformation2Response> {
|
||||
return this.client.sendOperationRequest(
|
||||
{
|
||||
subscriptionId,
|
||||
resourceGroupName,
|
||||
workspaceName,
|
||||
origin,
|
||||
container,
|
||||
options
|
||||
},
|
||||
getArtifactContentInformation2OperationSpec,
|
||||
callback) as Promise<Models.GetArtifactContentInformation2Response>;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
const serializer = new msRest.Serializer(Mappers);
|
||||
const getArtifactContentInformation2OperationSpec: msRest.OperationSpec = {
|
||||
httpMethod: 'GET',
|
||||
path: 'artifact/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}/artifacts/contentinfo/{origin}/{container}',
|
||||
urlParameters: [
|
||||
Parameters.subscriptionId,
|
||||
Parameters.resourceGroupName,
|
||||
Parameters.workspaceName,
|
||||
Parameters.origin,
|
||||
Parameters.container,
|
||||
Parameters.apiVersion
|
||||
],
|
||||
queryParameters: [
|
||||
Parameters.projectName0,
|
||||
Parameters.path1,
|
||||
Parameters.accountName
|
||||
],
|
||||
responses: {
|
||||
200: {
|
||||
bodyMapper: Mappers.ArtifactContentInformationDto
|
||||
},
|
||||
default: {}
|
||||
},
|
||||
serializer
|
||||
};
|
||||
78
extensions/machine-learning/src/modelManagement/assets.ts
Normal file
78
extensions/machine-learning/src/modelManagement/assets.ts
Normal file
@@ -0,0 +1,78 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as msRest from '@azure/ms-rest-js';
|
||||
import * as Models from './interfaces';
|
||||
import * as Mappers from './mappers';
|
||||
import * as Parameters from './parameters';
|
||||
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
|
||||
|
||||
export class Assets {
|
||||
private readonly client: AzureMachineLearningWorkspacesContext;
|
||||
|
||||
constructor(client: AzureMachineLearningWorkspacesContext) {
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
queryById(
|
||||
subscriptionId: string,
|
||||
resourceGroup: string,
|
||||
workspace: string,
|
||||
id: string,
|
||||
options?: msRest.RequestOptionsBase
|
||||
): Promise<Models.AssetsQueryByIdResponse>;
|
||||
queryById(
|
||||
subscriptionId: string,
|
||||
resourceGroup: string,
|
||||
workspace: string,
|
||||
id: string,
|
||||
callback: msRest.ServiceCallback<Models.Asset>
|
||||
): void;
|
||||
queryById(
|
||||
subscriptionId: string,
|
||||
resourceGroup: string,
|
||||
workspace: string,
|
||||
id: string,
|
||||
options: msRest.RequestOptionsBase,
|
||||
callback: msRest.ServiceCallback<Models.Asset>
|
||||
): void;
|
||||
queryById(
|
||||
subscriptionId: string,
|
||||
resourceGroup: string,
|
||||
workspace: string,
|
||||
id: string,
|
||||
options?: msRest.RequestOptionsBase | msRest.ServiceCallback<Models.Asset>,
|
||||
callback?: msRest.ServiceCallback<Models.Asset>
|
||||
): Promise<Models.AssetsQueryByIdResponse> {
|
||||
return this.client.sendOperationRequest(
|
||||
{
|
||||
subscriptionId,
|
||||
resourceGroup,
|
||||
workspace,
|
||||
id,
|
||||
options
|
||||
},
|
||||
queryByIdOperationSpec,
|
||||
callback
|
||||
) as Promise<Models.AssetsQueryByIdResponse>;
|
||||
}
|
||||
}
|
||||
|
||||
const serializer = new msRest.Serializer(Mappers);
|
||||
const queryByIdOperationSpec: msRest.OperationSpec = {
|
||||
httpMethod: 'GET',
|
||||
path:
|
||||
'modelmanagement/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/Microsoft.MachineLearningServices/workspaces/{workspace}/assets/{id}',
|
||||
urlParameters: [Parameters.subscriptionId, Parameters.resourceGroup, Parameters.workspace, Parameters.id],
|
||||
responses: {
|
||||
200: {
|
||||
bodyMapper: Mappers.Asset
|
||||
},
|
||||
default: {
|
||||
bodyMapper: Mappers.ModelErrorResponse
|
||||
}
|
||||
},
|
||||
serializer
|
||||
};
|
||||
@@ -0,0 +1,329 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as constants from '../common/constants';
|
||||
import { azureResource } from '../typings/azure-resource';
|
||||
import { AzureMachineLearningWorkspaces } from '@azure/arm-machinelearningservices';
|
||||
import { TokenCredentials } from '@azure/ms-rest-js';
|
||||
import { WorkspaceModels } from './workspacesModels';
|
||||
import { AzureMachineLearningWorkspacesOptions, Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { WorkspaceModel, Asset, IArtifactParts } from './interfaces';
|
||||
import { Config } from '../configurations/config';
|
||||
import { Assets } from './assets';
|
||||
import * as polly from 'polly-js';
|
||||
import { Artifacts } from './artifacts';
|
||||
import { HttpClient } from '../common/httpClient';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as utils from '../common/utils';
|
||||
|
||||
/**
|
||||
* Azure Model Service
|
||||
*/
|
||||
export class AzureModelRegistryService {
|
||||
|
||||
private _amlClient: AzureMachineLearningWorkspaces | undefined;
|
||||
private _modelClient: WorkspaceModels | undefined;
|
||||
/**
|
||||
* Creates new service
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _config: Config,
|
||||
private _httpClient: HttpClient,
|
||||
private _outputChannel: vscode.OutputChannel) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of azure accounts
|
||||
*/
|
||||
public async getAccounts(): Promise<azdata.Account[]> {
|
||||
return await this._apiWrapper.getAllAccounts();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of azure subscriptions
|
||||
* @param account azure account
|
||||
*/
|
||||
public async getSubscriptions(account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[] | undefined> {
|
||||
const data = <azureResource.GetSubscriptionsResult>await this._apiWrapper.executeCommand(constants.azureSubscriptionsCommand, account, true);
|
||||
return data?.subscriptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of azure groups
|
||||
* @param account azure account
|
||||
* @param subscription azure subscription
|
||||
*/
|
||||
public async getGroups(
|
||||
account: azdata.Account | undefined,
|
||||
subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[] | undefined> {
|
||||
const data = <azureResource.GetResourceGroupsResult>await this._apiWrapper.executeCommand(constants.azureResourceGroupsCommand, account, subscription, true);
|
||||
return data?.resourceGroups;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of workspaces
|
||||
* @param account azure account
|
||||
* @param subscription azure subscription
|
||||
* @param resourceGroup azure resource group
|
||||
*/
|
||||
public async getWorkspaces(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
|
||||
return await this.fetchWorkspaces(account, subscription, resourceGroup);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of models
|
||||
* @param account azure account
|
||||
* @param subscription azure subscription
|
||||
* @param resourceGroup azure resource group
|
||||
* @param workspace azure workspace
|
||||
*/
|
||||
public async getModels(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource,
|
||||
workspace: Workspace): Promise<WorkspaceModel[] | undefined> {
|
||||
return await this.fetchModels(account, subscription, resourceGroup, workspace);
|
||||
}
|
||||
|
||||
/**
|
||||
* Download an azure model to a temporary location
|
||||
* @param account azure account
|
||||
* @param subscription azure subscription
|
||||
* @param resourceGroup azure resource group
|
||||
* @param workspace azure workspace
|
||||
* @param model azure model
|
||||
*/
|
||||
public async downloadModel(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource,
|
||||
workspace: Workspace,
|
||||
model: WorkspaceModel): Promise<string> {
|
||||
let downloadedFilePath: string = '';
|
||||
|
||||
for (const tenant of account.properties.tenants) {
|
||||
try {
|
||||
const downloadUrls = await this.getAssetArtifactsDownloadLinks(account, subscription, resourceGroup, workspace, model, tenant);
|
||||
if (downloadUrls && downloadUrls.length > 0) {
|
||||
downloadedFilePath = await this.execDownloadArtifactTask(downloadUrls[0]);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
}
|
||||
}
|
||||
return downloadedFilePath;
|
||||
}
|
||||
|
||||
public set AzureMachineLearningClient(value: AzureMachineLearningWorkspaces) {
|
||||
this._amlClient = value;
|
||||
}
|
||||
|
||||
public set ModelClient(value: WorkspaceModels) {
|
||||
this._modelClient = value;
|
||||
}
|
||||
|
||||
public async signInToAzure(): Promise<void> {
|
||||
await this._apiWrapper.executeCommand(constants.signInToAzureCommand);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the background task to download the artifact
|
||||
*/
|
||||
private async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
|
||||
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
|
||||
return results && results.length > 0 ? results[0] : constants.noResultError;
|
||||
}
|
||||
|
||||
private async downloadArtifact(downloadUrl: string): Promise<string> {
|
||||
let tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await this._httpClient.download(downloadUrl, tempFilePath, this._outputChannel);
|
||||
return tempFilePath;
|
||||
}
|
||||
|
||||
private async fetchWorkspaces(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
|
||||
let resources: Workspace[] = [];
|
||||
|
||||
try {
|
||||
for (const tenant of account.properties.tenants) {
|
||||
const client = await this.getAmlClient(account, subscription, tenant);
|
||||
let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription();
|
||||
if (result) {
|
||||
resources.push(...result);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
}
|
||||
return resources;
|
||||
}
|
||||
|
||||
private async fetchModels(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource,
|
||||
workspace: Workspace): Promise<WorkspaceModel[]> {
|
||||
let resources: WorkspaceModel[] = [];
|
||||
|
||||
for (const tenant of account.properties.tenants) {
|
||||
try {
|
||||
let options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
|
||||
};
|
||||
const client = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
let modelsClient = this.getModelClient(client);
|
||||
resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || ''));
|
||||
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
}
|
||||
}
|
||||
|
||||
return resources;
|
||||
}
|
||||
|
||||
private async fetchModelAsset(
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource,
|
||||
workspace: Workspace,
|
||||
model: WorkspaceModel,
|
||||
client: AzureMachineLearningWorkspaces): Promise<Asset> {
|
||||
|
||||
const modelId = this.getModelId(model);
|
||||
if (modelId) {
|
||||
let modelsClient = new Assets(client);
|
||||
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
|
||||
} else {
|
||||
throw Error(constants.invalidModelIdError(model.url));
|
||||
}
|
||||
}
|
||||
|
||||
private async getAssetArtifactsDownloadLinks(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
resourceGroup: azureResource.AzureResource,
|
||||
workspace: Workspace,
|
||||
model: WorkspaceModel,
|
||||
tenant: any): Promise<string[]> {
|
||||
let options: AzureMachineLearningWorkspacesOptions = {
|
||||
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
|
||||
};
|
||||
const modelManagementClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient);
|
||||
options.baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
|
||||
const experienceClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
|
||||
const artifactClient = new Artifacts(experienceClient);
|
||||
let downloadLinks: string[] = [];
|
||||
if (asset && asset.artifacts) {
|
||||
const downloadLinkPromises: Array<Promise<string>> = [];
|
||||
for (const artifact of asset.artifacts) {
|
||||
const parts = artifact.id
|
||||
? this.getPartsFromAssetIdOrPrefix(artifact.id)
|
||||
: this.getPartsFromAssetIdOrPrefix(artifact.prefix);
|
||||
|
||||
if (parts) {
|
||||
const promise = polly()
|
||||
.waitAndRetry(3)
|
||||
.executeForPromise(
|
||||
async (): Promise<string> => {
|
||||
const artifact = await artifactClient.getArtifactContentInformation2(
|
||||
experienceClient.subscriptionId,
|
||||
resourceGroup.name,
|
||||
workspace.name || '',
|
||||
parts.origin,
|
||||
parts.container,
|
||||
{ path: parts.path }
|
||||
);
|
||||
if (artifact) {
|
||||
return artifact.contentUri || '';
|
||||
} else {
|
||||
return Promise.reject();
|
||||
}
|
||||
}
|
||||
);
|
||||
downloadLinkPromises.push(promise);
|
||||
}
|
||||
}
|
||||
try {
|
||||
downloadLinks = await Promise.all(downloadLinkPromises);
|
||||
} catch (rejectedPromiseError) {
|
||||
return rejectedPromiseError;
|
||||
}
|
||||
return downloadLinks;
|
||||
|
||||
} else {
|
||||
throw Error(constants.noArtifactError(model.url));
|
||||
}
|
||||
}
|
||||
|
||||
private getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
|
||||
const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/;
|
||||
if (idOrPrefix) {
|
||||
const parts = artifactRegex.exec(idOrPrefix);
|
||||
if (parts && parts.length === 4) {
|
||||
return {
|
||||
origin: parts[1],
|
||||
container: parts[2],
|
||||
path: parts[3]
|
||||
};
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private getBaseUrl(workspace: Workspace, server: string): string {
|
||||
let baseUri = `https://${workspace.location}.${server}`;
|
||||
if (workspace.location === 'chinaeast2') {
|
||||
baseUri = `https://${workspace.location}.${server}`;
|
||||
}
|
||||
return baseUri;
|
||||
}
|
||||
|
||||
private getModelClient(amlClient: AzureMachineLearningWorkspaces) {
|
||||
return this._modelClient ?? new WorkspaceModels(amlClient);
|
||||
}
|
||||
|
||||
private async getAmlClient(
|
||||
account: azdata.Account,
|
||||
subscription: azureResource.AzureResourceSubscription,
|
||||
tenant: any,
|
||||
options: AzureMachineLearningWorkspacesOptions | undefined = undefined,
|
||||
apiVersion: string | undefined = undefined): Promise<AzureMachineLearningWorkspaces> {
|
||||
if (this._amlClient) {
|
||||
return this._amlClient;
|
||||
} else {
|
||||
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
|
||||
let token: string = '';
|
||||
let tokenType: string | undefined = undefined;
|
||||
if (tokens && tenant.id in tokens) {
|
||||
const tokenForId = tokens[tenant.id];
|
||||
if (tokenForId) {
|
||||
token = tokenForId.token;
|
||||
tokenType = tokenForId.tokenType;
|
||||
}
|
||||
}
|
||||
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
|
||||
if (apiVersion) {
|
||||
client.apiVersion = apiVersion;
|
||||
}
|
||||
return client;
|
||||
}
|
||||
}
|
||||
|
||||
private getModelId(model: WorkspaceModel): string {
|
||||
const amlAssetRegex = /^aml:\/\/asset\/(.+)$/;
|
||||
const id = model ? amlAssetRegex.exec(model.url || '') : undefined;
|
||||
return id && id.length === 2 ? id[1] : '';
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import { Config } from '../configurations/config';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { ImportedModel, ImportedModelDetails, ModelParameters } from './interfaces';
|
||||
import { ModelPythonClient } from './modelPythonClient';
|
||||
import * as constants from '../common/constants';
|
||||
import * as queries from './queries';
|
||||
import { DatabaseTable } from '../prediction/interfaces';
|
||||
import { ModelConfigRecent } from './modelConfigRecent';
|
||||
|
||||
/**
|
||||
* Service to deployed models
|
||||
*/
|
||||
export class DeployedModelService {
|
||||
|
||||
/**
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _config: Config,
|
||||
private _queryRunner: QueryRunner,
|
||||
private _modelClient: ModelPythonClient,
|
||||
private _recentModelService: ModelConfigRecent) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns deployed models
|
||||
*/
|
||||
public async getDeployedModels(table: DatabaseTable): Promise<ImportedModel[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: ImportedModel[] = [];
|
||||
if (!table.databaseName || !table.tableName || !table.schema) {
|
||||
return [];
|
||||
}
|
||||
if (connection) {
|
||||
const query = queries.getDeployedModelsQuery(table);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push(this.loadModelData(row, table));
|
||||
});
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.noConnectionError);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
/**
|
||||
* Downloads model
|
||||
* @param model model object
|
||||
*/
|
||||
public async downloadModel(model: ImportedModel): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
const query = queries.getModelContentQuery(model);
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
const content = result.rows[0][0].displayValue;
|
||||
return await utils.writeFileFromHex(content);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToSelectError);
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.noConnectionError);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(filePath: string): Promise<ModelParameters> {
|
||||
return await this._modelClient.loadModelParameters(filePath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Deploys local model
|
||||
* @param filePath model file path
|
||||
* @param details model details
|
||||
*/
|
||||
public async deployLocalModel(filePath: string, details: ImportedModelDetails | undefined, table: DatabaseTable) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection && table.databaseName) {
|
||||
|
||||
await this.configureImport(connection, table);
|
||||
let currentModels = await this.getDeployedModels(table);
|
||||
const content = await utils.readFileInHex(filePath);
|
||||
let modelToAdd: ImportedModel = Object.assign({}, {
|
||||
id: 0,
|
||||
content: content,
|
||||
table: table
|
||||
}, details);
|
||||
await this._queryRunner.runWithDatabaseChange(connection, queries.getInsertModelQuery(modelToAdd, table), table.databaseName);
|
||||
|
||||
let updatedModels = await this.getDeployedModels(table);
|
||||
if (updatedModels.length < currentModels.length + 1) {
|
||||
throw Error(constants.importModelFailedError(details?.modelName, filePath));
|
||||
}
|
||||
|
||||
} else {
|
||||
throw new Error(constants.noConnectionError);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a model
|
||||
*/
|
||||
public async updateModel(model: ImportedModel) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection && model && model.table && model.table.databaseName) {
|
||||
await this._queryRunner.runWithDatabaseChange(connection, queries.getUpdateModelQuery(model), model.table.databaseName);
|
||||
} else {
|
||||
throw new Error(constants.noConnectionError);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a model
|
||||
*/
|
||||
public async deleteModel(model: ImportedModel) {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection && model && model.table && model.table.databaseName) {
|
||||
await this._queryRunner.runWithDatabaseChange(connection, queries.getDeleteModelQuery(model), model.table.databaseName);
|
||||
} else {
|
||||
throw new Error(constants.noConnectionError);
|
||||
}
|
||||
}
|
||||
|
||||
public async configureImport(connection: azdata.connection.ConnectionProfile, table: DatabaseTable) {
|
||||
if (connection && table.databaseName) {
|
||||
let query = queries.getDatabaseConfigureQuery(table);
|
||||
await this._queryRunner.safeRunQuery(connection, query);
|
||||
|
||||
query = queries.getConfigureTableQuery(table);
|
||||
await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Verifies if the given table name is valid to be used as import table. If table doesn't exist returns true to create new table
|
||||
* Otherwise verifies the schema and returns true if the schema is supported
|
||||
* @param connection database connection
|
||||
* @param table config table name
|
||||
*/
|
||||
public async verifyConfigTable(table: DatabaseTable): Promise<boolean> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection && table.databaseName) {
|
||||
let databases = await this._apiWrapper.listDatabases(connection.connectionId);
|
||||
|
||||
// If database exist verify the table schema
|
||||
//
|
||||
if ((await databases).find(x => x === table.databaseName)) {
|
||||
const query = queries.getConfigTableVerificationQuery(table);
|
||||
const result = await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
|
||||
return result !== undefined && result.rows.length > 0 && result.rows[0][0].displayValue === '1';
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
throw new Error(constants.noConnectionError);
|
||||
}
|
||||
}
|
||||
|
||||
public async getRecentImportTable(): Promise<DatabaseTable> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let table: DatabaseTable | undefined;
|
||||
if (connection) {
|
||||
table = this._recentModelService.getModelTable(connection);
|
||||
if (!table) {
|
||||
table = {
|
||||
databaseName: connection.databaseName ?? 'master',
|
||||
tableName: this._config.registeredModelTableName,
|
||||
schema: this._config.registeredModelTableSchemaName
|
||||
};
|
||||
}
|
||||
} else {
|
||||
throw new Error(constants.noConnectionError);
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
public async storeRecentImportTable(importTable: DatabaseTable): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
this._recentModelService.storeModelTable(connection, importTable);
|
||||
} else {
|
||||
throw new Error(constants.noConnectionError);
|
||||
}
|
||||
}
|
||||
|
||||
private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): ImportedModel {
|
||||
return {
|
||||
id: +row[0].displayValue,
|
||||
modelName: row[1].displayValue,
|
||||
description: row[2].displayValue,
|
||||
version: row[3].displayValue,
|
||||
created: row[4].displayValue,
|
||||
framework: row[5].displayValue,
|
||||
frameworkVersion: row[6].displayValue,
|
||||
deploymentTime: row[7].displayValue,
|
||||
deployedBy: row[8].displayValue,
|
||||
runId: row[9].displayValue,
|
||||
table: table
|
||||
};
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
}
|
||||
241
extensions/machine-learning/src/modelManagement/interfaces.ts
Normal file
241
extensions/machine-learning/src/modelManagement/interfaces.ts
Normal file
@@ -0,0 +1,241 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as msRest from '@azure/ms-rest-js';
|
||||
import { Resource } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { DatabaseTable } from '../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* An interface representing ListWorkspaceModelResult.
|
||||
*/
|
||||
export interface ListWorkspaceModelsResult extends Array<WorkspaceModel> {
|
||||
}
|
||||
|
||||
/**
|
||||
* An interface representing Workspace model
|
||||
*/
|
||||
export interface WorkspaceModel extends Resource {
|
||||
framework?: string;
|
||||
frameworkVersion?: string;
|
||||
createdBy?: string;
|
||||
createdTime?: string;
|
||||
experimentName?: string;
|
||||
outputsSchema?: Array<string>;
|
||||
url?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* An interface representing Workspace model list response
|
||||
*/
|
||||
export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
|
||||
/**
|
||||
* The underlying HTTP response.
|
||||
*/
|
||||
_response: msRest.HttpResponse & {
|
||||
/**
|
||||
* The response body as text (string format)
|
||||
*/
|
||||
bodyAsText: string;
|
||||
|
||||
/**
|
||||
* The response body as parsed JSON or XML
|
||||
*/
|
||||
parsedBody: ListWorkspaceModelsResult;
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* An interface representing imported model
|
||||
*/
|
||||
export interface ImportedModel extends ImportedModelDetails {
|
||||
id: number;
|
||||
content?: string;
|
||||
table: DatabaseTable;
|
||||
}
|
||||
|
||||
export interface ModelParameter {
|
||||
name: string;
|
||||
type: string;
|
||||
}
|
||||
|
||||
export interface ModelParameters {
|
||||
inputs: ModelParameter[],
|
||||
outputs: ModelParameter[]
|
||||
}
|
||||
|
||||
/**
|
||||
* An interface representing imported model
|
||||
*/
|
||||
export interface ImportedModelDetails {
|
||||
modelName: string;
|
||||
created?: string;
|
||||
deploymentTime?: string;
|
||||
version?: string;
|
||||
description?: string;
|
||||
fileName?: string;
|
||||
framework?: string;
|
||||
frameworkVersion?: string;
|
||||
runId?: string;
|
||||
deployedBy?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* The Artifact definition.
|
||||
*/
|
||||
export interface ArtifactDetails {
|
||||
/**
|
||||
* The Artifact Id.
|
||||
*/
|
||||
id?: string;
|
||||
/**
|
||||
* The Artifact prefix.
|
||||
*/
|
||||
prefix?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* @interface
|
||||
* An interface representing Asset.
|
||||
* The Asset definition.
|
||||
*
|
||||
*/
|
||||
export interface Asset {
|
||||
/**
|
||||
* @member {string} [id] The Asset Id.
|
||||
*/
|
||||
id?: string;
|
||||
/**
|
||||
* @member {string} [name] The name of the Asset.
|
||||
*/
|
||||
name?: string;
|
||||
/**
|
||||
* @member {string} [description] The Asset description.
|
||||
*/
|
||||
description?: string;
|
||||
/**
|
||||
* @member {ArtifactDetails[]} [artifacts] A list of child artifacts.
|
||||
*/
|
||||
artifacts?: ArtifactDetails[];
|
||||
/**
|
||||
* @member {string[]} [tags] The Asset tag list.
|
||||
*/
|
||||
tags?: string[];
|
||||
/**
|
||||
* @member {{ [propertyName: string]: string }} [kvTags] The Asset tag
|
||||
* dictionary. Tags are mutable.
|
||||
*/
|
||||
kvTags?: { [propertyName: string]: string };
|
||||
/**
|
||||
* @member {{ [propertyName: string]: string }} [properties] The Asset
|
||||
* property dictionary. Properties are immutable.
|
||||
*/
|
||||
properties?: { [propertyName: string]: string };
|
||||
/**
|
||||
* @member {string} [runid] The RunId associated with this Asset.
|
||||
*/
|
||||
runid?: string;
|
||||
/**
|
||||
* @member {string} [projectid] The project Id.
|
||||
*/
|
||||
projectid?: string;
|
||||
/**
|
||||
* @member {{ [propertyName: string]: string }} [meta] A dictionary
|
||||
* containing metadata about the Asset.
|
||||
*/
|
||||
meta?: { [propertyName: string]: string };
|
||||
/**
|
||||
* @member {Date} [createdTime] The time the Asset was created in UTC.
|
||||
*/
|
||||
createdTime?: Date;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Contains response data for the queryById operation.
|
||||
*/
|
||||
export type AssetsQueryByIdResponse = Asset & {
|
||||
/**
|
||||
* The underlying HTTP response.
|
||||
*/
|
||||
_response: msRest.HttpResponse & {
|
||||
/**
|
||||
* The response body as text (string format)
|
||||
*/
|
||||
bodyAsText: string;
|
||||
/**
|
||||
* The response body as parsed JSON or XML
|
||||
*/
|
||||
parsedBody: Asset;
|
||||
};
|
||||
};
|
||||
|
||||
export interface IArtifactParts {
|
||||
origin: string;
|
||||
container: string;
|
||||
path: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* @interface
|
||||
* An interface representing ArtifactContentInformationDto.
|
||||
*/
|
||||
export interface ArtifactContentInformationDto {
|
||||
/**
|
||||
* @member {string} [contentUri]
|
||||
*/
|
||||
contentUri?: string;
|
||||
/**
|
||||
* @member {string} [origin]
|
||||
*/
|
||||
origin?: string;
|
||||
/**
|
||||
* @member {string} [container]
|
||||
*/
|
||||
container?: string;
|
||||
/**
|
||||
* @member {string} [path]
|
||||
*/
|
||||
path?: string;
|
||||
}
|
||||
/**
|
||||
* Contains response data for the getArtifactContentInformation2 operation.
|
||||
*/
|
||||
export type GetArtifactContentInformation2Response = ArtifactContentInformationDto & {
|
||||
/**
|
||||
* The underlying HTTP response.
|
||||
*/
|
||||
_response: msRest.HttpResponse & {
|
||||
/**
|
||||
* The response body as text (string format)
|
||||
*/
|
||||
bodyAsText: string;
|
||||
/**
|
||||
* The response body as parsed JSON or XML
|
||||
*/
|
||||
parsedBody: ArtifactContentInformationDto;
|
||||
};
|
||||
};
|
||||
/**
|
||||
* @interface
|
||||
* An interface representing ArtifactAPIGetArtifactContentInformation2OptionalParams.
|
||||
* Optional Parameters.
|
||||
*
|
||||
* @extends RequestOptionsBase
|
||||
*/
|
||||
export interface ArtifactAPIGetArtifactContentInformation2OptionalParams extends msRest.RequestOptionsBase {
|
||||
/**
|
||||
* @member {string} [projectName]
|
||||
*/
|
||||
projectName?: string;
|
||||
/**
|
||||
* @member {string} [path]
|
||||
*/
|
||||
path?: string;
|
||||
/**
|
||||
* @member {string} [accountName]
|
||||
*/
|
||||
accountName?: string;
|
||||
}
|
||||
|
||||
320
extensions/machine-learning/src/modelManagement/mappers.ts
Normal file
320
extensions/machine-learning/src/modelManagement/mappers.ts
Normal file
@@ -0,0 +1,320 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as msRest from '@azure/ms-rest-js';
|
||||
|
||||
export const Resource: msRest.CompositeMapper = {
|
||||
serializedName: 'Resource',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'Resource',
|
||||
modelProperties: {
|
||||
id: {
|
||||
readOnly: true,
|
||||
serializedName: 'id',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
name: {
|
||||
readOnly: true,
|
||||
serializedName: 'name',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
identity: {
|
||||
readOnly: true,
|
||||
serializedName: 'identity',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'Identity'
|
||||
}
|
||||
},
|
||||
location: {
|
||||
serializedName: 'location',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
type: {
|
||||
readOnly: true,
|
||||
serializedName: 'type',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
tags: {
|
||||
serializedName: 'tags',
|
||||
type: {
|
||||
name: 'Dictionary',
|
||||
value: {
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const ListWorkspaceModelsResult: msRest.CompositeMapper = {
|
||||
serializedName: 'ListWorkspaceModelsResult',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'ListWorkspaceModelsResult',
|
||||
modelProperties: {
|
||||
value: {
|
||||
serializedName: '',
|
||||
type: {
|
||||
name: 'Sequence',
|
||||
element: {
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'WorkspaceModel'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
nextLink: {
|
||||
serializedName: 'nextLink',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const WorkspaceModel: msRest.CompositeMapper = {
|
||||
serializedName: 'WorkspaceModel',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'WorkspaceModel',
|
||||
modelProperties: {
|
||||
...Resource.type.modelProperties,
|
||||
framework: {
|
||||
readOnly: true,
|
||||
serializedName: 'framework',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const MachineLearningServiceError: msRest.CompositeMapper = {
|
||||
serializedName: 'MachineLearningServiceError',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'MachineLearningServiceError',
|
||||
modelProperties: {
|
||||
error: {
|
||||
readOnly: true,
|
||||
serializedName: 'error',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'ErrorResponse'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
export const ModelErrorResponse: msRest.CompositeMapper = {
|
||||
serializedName: 'ModelErrorResponse',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'ModelErrorResponse',
|
||||
modelProperties: {
|
||||
code: {
|
||||
serializedName: 'code',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
statusCode: {
|
||||
serializedName: 'statusCode',
|
||||
type: {
|
||||
name: 'Number'
|
||||
}
|
||||
},
|
||||
message: {
|
||||
serializedName: 'message',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
details: {
|
||||
serializedName: 'details',
|
||||
type: {
|
||||
name: 'Sequence',
|
||||
element: {
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'ErrorDetails'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
export const ArtifactDetails: msRest.CompositeMapper = {
|
||||
serializedName: 'ArtifactDetails',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'ArtifactDetails',
|
||||
modelProperties: {
|
||||
id: {
|
||||
serializedName: 'id',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
prefix: {
|
||||
serializedName: 'prefix',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
export const Asset: msRest.CompositeMapper = {
|
||||
serializedName: 'Asset',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'Asset',
|
||||
modelProperties: {
|
||||
id: {
|
||||
serializedName: 'id',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
name: {
|
||||
serializedName: 'name',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
description: {
|
||||
serializedName: 'description',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
artifacts: {
|
||||
serializedName: 'artifacts',
|
||||
type: {
|
||||
name: 'Sequence',
|
||||
element: {
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'ArtifactDetails'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
tags: {
|
||||
serializedName: 'tags',
|
||||
type: {
|
||||
name: 'Sequence',
|
||||
element: {
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
kvTags: {
|
||||
serializedName: 'kvTags',
|
||||
type: {
|
||||
name: 'Dictionary',
|
||||
value: {
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
properties: {
|
||||
serializedName: 'properties',
|
||||
type: {
|
||||
name: 'Dictionary',
|
||||
value: {
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
runid: {
|
||||
serializedName: 'runid',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
projectid: {
|
||||
serializedName: 'projectid',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
meta: {
|
||||
serializedName: 'meta',
|
||||
type: {
|
||||
name: 'Dictionary',
|
||||
value: {
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
createdTime: {
|
||||
serializedName: 'createdTime',
|
||||
type: {
|
||||
name: 'DateTime'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
export const ArtifactContentInformationDto: msRest.CompositeMapper = {
|
||||
serializedName: 'ArtifactContentInformationDto',
|
||||
type: {
|
||||
name: 'Composite',
|
||||
className: 'ArtifactContentInformationDto',
|
||||
modelProperties: {
|
||||
contentUri: {
|
||||
serializedName: 'contentUri',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
origin: {
|
||||
serializedName: 'origin',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
container: {
|
||||
serializedName: 'container',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
},
|
||||
path: {
|
||||
serializedName: 'path',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,30 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import { DatabaseTable } from '../prediction/interfaces';
|
||||
|
||||
const TableConfigName = 'MLS_ModelTableConfigName';
|
||||
|
||||
export class ModelConfigRecent {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor(private _memento: vscode.Memento) {
|
||||
}
|
||||
|
||||
public getModelTable(connection: azdata.connection.ConnectionProfile): DatabaseTable | undefined {
|
||||
return this._memento.get<DatabaseTable>(this.getKey(connection));
|
||||
}
|
||||
|
||||
public storeModelTable(connection: azdata.connection.ConnectionProfile, table: DatabaseTable): void {
|
||||
this._memento.update(this.getKey(connection), table);
|
||||
}
|
||||
|
||||
private getKey(connection: azdata.connection.ConnectionProfile): string {
|
||||
return `${TableConfigName}_${connection.serverName}`;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { ProcessService } from '../common/processService';
|
||||
import { Config } from '../configurations/config';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as utils from '../common/utils';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as constants from '../common/constants';
|
||||
import * as os from 'os';
|
||||
import { ModelParameters } from './interfaces';
|
||||
|
||||
/**
|
||||
* Python client for ONNX models
|
||||
*/
|
||||
export class ModelPythonClient {
|
||||
|
||||
/**
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Deploys models in the SQL database using mlflow
|
||||
* @param connection
|
||||
* @param modelPath
|
||||
*/
|
||||
public async deployModel(connection: azdata.connection.ConnectionProfile, modelPath: string): Promise<void> {
|
||||
await this.installDependencies();
|
||||
await this.executeDeployScripts(connection, modelPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for python client
|
||||
*/
|
||||
private async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [
|
||||
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelPath Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
|
||||
await this.installDependencies();
|
||||
return await this.executeModelParametersScripts(modelPath);
|
||||
}
|
||||
|
||||
private async executeModelParametersScripts(modelFolderPath: string): Promise<ModelParameters> {
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
let scripts: string[] = [
|
||||
'import onnx',
|
||||
'import json',
|
||||
`onnx_model_path = '${modelFolderPath}'`,
|
||||
`onnx_model = onnx.load_model(onnx_model_path)`,
|
||||
`type_map = {
|
||||
onnx.TensorProto.DataType.FLOAT: 'real',
|
||||
onnx.TensorProto.DataType.UINT8: 'tinyint',
|
||||
onnx.TensorProto.DataType.INT16: 'smallint',
|
||||
onnx.TensorProto.DataType.INT32: 'int',
|
||||
onnx.TensorProto.DataType.INT64: 'bigint',
|
||||
onnx.TensorProto.DataType.STRING: 'varchar(MAX)',
|
||||
onnx.TensorProto.DataType.DOUBLE: 'float'}`,
|
||||
`parameters = {
|
||||
"inputs": [],
|
||||
"outputs": []
|
||||
}`,
|
||||
`def addParameters(list, paramType):
|
||||
for id, p in enumerate(list):
|
||||
p_type = ''
|
||||
|
||||
if p.type.tensor_type.elem_type in type_map:
|
||||
p_type = type_map[p.type.tensor_type.elem_type]
|
||||
|
||||
parameters[paramType].append({
|
||||
'name': p.name,
|
||||
'type': p_type
|
||||
})`,
|
||||
|
||||
'addParameters(onnx_model.graph.input, "inputs")',
|
||||
'addParameters(onnx_model.graph.output, "outputs")',
|
||||
'print(json.dumps(parameters))'
|
||||
];
|
||||
let pythonExecutable = this._config.pythonExecutable;
|
||||
let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
|
||||
let parametersJson = JSON.parse(output);
|
||||
return Object.assign({}, parametersJson);
|
||||
}
|
||||
|
||||
private async executeDeployScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
|
||||
let home = utils.makeLinuxPath(os.homedir());
|
||||
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
|
||||
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
|
||||
if (connection) {
|
||||
let server = connection.serverName;
|
||||
|
||||
const experimentId = `ads_ml_experiment_${UUID.generateUuid()}`;
|
||||
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
|
||||
let scripts: string[] = [
|
||||
'import mlflow.onnx',
|
||||
`tracking_uri = "file://${home}/mlruns"`,
|
||||
'print(tracking_uri)',
|
||||
'import onnx',
|
||||
'from mlflow.tracking.client import MlflowClient',
|
||||
`onx = onnx.load("${modelFolderPath}")`,
|
||||
`mlflow.set_tracking_uri(tracking_uri)`,
|
||||
'client = MlflowClient()',
|
||||
`exp_name = "${experimentId}"`,
|
||||
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
|
||||
'client.create_experiment(exp_name, artifact_location=db_uri_artifact)',
|
||||
'mlflow.set_experiment(exp_name)',
|
||||
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'
|
||||
];
|
||||
let pythonExecutable = this._config.pythonExecutable;
|
||||
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
|
||||
}
|
||||
}
|
||||
}
|
||||
143
extensions/machine-learning/src/modelManagement/parameters.ts
Normal file
143
extensions/machine-learning/src/modelManagement/parameters.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as msRest from '@azure/ms-rest-js';
|
||||
|
||||
export const subscriptionId: msRest.OperationURLParameter = {
|
||||
parameterPath: 'subscriptionId',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'subscriptionId',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const resourceGroupName: msRest.OperationURLParameter = {
|
||||
parameterPath: 'resourceGroupName',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'resourceGroupName',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const workspaceName: msRest.OperationURLParameter = {
|
||||
parameterPath: 'workspaceName',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'workspaceName',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const workspace: msRest.OperationURLParameter = {
|
||||
parameterPath: 'workspace',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'workspace',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const resourceGroup: msRest.OperationURLParameter = {
|
||||
parameterPath: 'resourceGroup',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'resourceGroup',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const id: msRest.OperationURLParameter = {
|
||||
parameterPath: 'id',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'id',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const acceptLanguage: msRest.OperationParameter = {
|
||||
parameterPath: 'acceptLanguage',
|
||||
mapper: {
|
||||
serializedName: 'accept-language',
|
||||
defaultValue: 'en-US',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const apiVersion: msRest.OperationQueryParameter = {
|
||||
parameterPath: 'apiVersion',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'api-version',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const origin: msRest.OperationURLParameter = {
|
||||
parameterPath: 'origin',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'origin',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const container: msRest.OperationURLParameter = {
|
||||
parameterPath: 'container',
|
||||
mapper: {
|
||||
required: true,
|
||||
serializedName: 'container',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const projectName0: msRest.OperationQueryParameter = {
|
||||
parameterPath: [
|
||||
'options',
|
||||
'projectName'
|
||||
],
|
||||
mapper: {
|
||||
serializedName: 'projectName',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const path1: msRest.OperationQueryParameter = {
|
||||
parameterPath: [
|
||||
'options',
|
||||
'path'
|
||||
],
|
||||
mapper: {
|
||||
serializedName: 'path',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
export const accountName: msRest.OperationQueryParameter = {
|
||||
parameterPath: [
|
||||
'options',
|
||||
'accountName'
|
||||
],
|
||||
mapper: {
|
||||
serializedName: 'accountName',
|
||||
type: {
|
||||
name: 'String'
|
||||
}
|
||||
}
|
||||
};
|
||||
195
extensions/machine-learning/src/modelManagement/queries.ts
Normal file
195
extensions/machine-learning/src/modelManagement/queries.ts
Normal file
@@ -0,0 +1,195 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as utils from '../common/utils';
|
||||
import { DatabaseTable } from '../prediction/interfaces';
|
||||
import { ImportedModel } from './interfaces';
|
||||
|
||||
export function getDatabaseConfigureQuery(configTable: DatabaseTable): string {
|
||||
return `
|
||||
IF NOT EXISTS (
|
||||
SELECT name
|
||||
FROM sys.databases
|
||||
WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}'
|
||||
)
|
||||
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}]
|
||||
`;
|
||||
}
|
||||
|
||||
export function getDeployedModelsQuery(table: DatabaseTable): string {
|
||||
return `
|
||||
${selectQuery}
|
||||
FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')}
|
||||
WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
|
||||
ORDER BY model_id
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Verifies config table has the expected schema
|
||||
* @param databaseName
|
||||
* @param tableName
|
||||
*/
|
||||
export function getConfigTableVerificationQuery(table: DatabaseTable): string {
|
||||
let tableName = table.tableName;
|
||||
let schemaName = table.schema;
|
||||
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
|
||||
|
||||
return `
|
||||
IF NOT EXISTS (
|
||||
SELECT name
|
||||
FROM sys.databases
|
||||
WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}'
|
||||
)
|
||||
BEGIN
|
||||
SELECT 1
|
||||
END
|
||||
ELSE
|
||||
BEGIN
|
||||
USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}]
|
||||
IF EXISTS
|
||||
( SELECT t.name, s.name
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
||||
)
|
||||
BEGIN
|
||||
IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_name')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_id')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_description')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework_version')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_version')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_creation_time')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_deployment_time')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='deployed_by')
|
||||
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='run_id')
|
||||
BEGIN
|
||||
SELECT 1
|
||||
END
|
||||
ELSE
|
||||
BEGIN
|
||||
SELECT 0
|
||||
END
|
||||
END
|
||||
ELSE
|
||||
SELECT 1
|
||||
END
|
||||
`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the import table if doesn't exist
|
||||
*/
|
||||
export function getConfigureTableQuery(table: DatabaseTable): string {
|
||||
let tableName = table.tableName;
|
||||
let schemaName = table.schema;
|
||||
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
|
||||
|
||||
return `
|
||||
IF NOT EXISTS
|
||||
( SELECT t.name, s.name
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
|
||||
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
|
||||
)
|
||||
BEGIN
|
||||
CREATE TABLE ${twoPartTableName}(
|
||||
[model_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[model_name] [varchar](256) NOT NULL,
|
||||
[model_framework] [varchar](256) NULL,
|
||||
[model_framework_version] [varchar](256) NULL,
|
||||
[model] [varbinary](max) NOT NULL,
|
||||
[model_version] [varchar](256) NULL,
|
||||
[model_creation_time] [datetime2] NULL,
|
||||
[model_deployment_time] [datetime2] NULL,
|
||||
[deployed_by] [int] NULL,
|
||||
[model_description] [varchar](256) NULL,
|
||||
[run_id] [varchar](256) NULL,
|
||||
CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_models_pk] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[model_id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||
ALTER TABLE ${twoPartTableName} ADD CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
|
||||
END
|
||||
`;
|
||||
}
|
||||
|
||||
export function getInsertModelQuery(model: ImportedModel, table: DatabaseTable): string {
|
||||
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
|
||||
const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '');
|
||||
let updateScript = `
|
||||
INSERT INTO ${twoPartTableName}
|
||||
(model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id)
|
||||
VALUES (
|
||||
'${utils.doubleEscapeSingleQuotes(model.modelName || '')}',
|
||||
${utils.doubleEscapeSingleQuotes(model.content || '')},
|
||||
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
'${utils.doubleEscapeSingleQuotes(model.description || '')}',
|
||||
'${utils.doubleEscapeSingleQuotes(model.created || '')}',
|
||||
'${utils.doubleEscapeSingleQuotes(model.framework || '')}',
|
||||
'${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
|
||||
'${utils.doubleEscapeSingleQuotes(model.runId || '')}')
|
||||
`;
|
||||
|
||||
return `
|
||||
${updateScript}
|
||||
${selectQuery}
|
||||
FROM ${threePartTableName}
|
||||
WHERE model_id = SCOPE_IDENTITY();
|
||||
`;
|
||||
}
|
||||
|
||||
export function getModelContentQuery(model: ImportedModel): string {
|
||||
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
|
||||
return `
|
||||
SELECT model
|
||||
FROM ${threePartTableName}
|
||||
WHERE model_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
|
||||
export function getUpdateModelQuery(model: ImportedModel): string {
|
||||
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || '');
|
||||
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
|
||||
let updateScript = `
|
||||
UPDATE ${twoPartTableName}
|
||||
SET
|
||||
model_name = '${utils.doubleEscapeSingleQuotes(model.modelName || '')}',
|
||||
model_version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
|
||||
model_description = '${utils.doubleEscapeSingleQuotes(model.description || '')}',
|
||||
model_creation_time = '${utils.doubleEscapeSingleQuotes(model.created || '')}',
|
||||
model_framework = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
|
||||
model_framework_version = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
|
||||
run_id = '${utils.doubleEscapeSingleQuotes(model.runId || '')}'
|
||||
WHERE model_id = ${model.id}`;
|
||||
|
||||
return `
|
||||
${updateScript}
|
||||
${selectQuery}
|
||||
FROM ${threePartTableName}
|
||||
WHERE model_id = ${model.id};
|
||||
`;
|
||||
}
|
||||
|
||||
export function getDeleteModelQuery(model: ImportedModel): string {
|
||||
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || '');
|
||||
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
|
||||
let updateScript = `
|
||||
Delete from ${twoPartTableName}
|
||||
WHERE model_id = ${model.id}`;
|
||||
|
||||
return `
|
||||
${updateScript}
|
||||
${selectQuery}
|
||||
FROM ${threePartTableName}
|
||||
`;
|
||||
}
|
||||
|
||||
export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id';
|
||||
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as msRest from '@azure/ms-rest-js';
|
||||
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
|
||||
import * as Models from './interfaces';
|
||||
import * as Mappers from './mappers';
|
||||
import * as Parameters from './parameters';
|
||||
|
||||
/**
|
||||
* Workspace models client
|
||||
*/
|
||||
export class WorkspaceModels {
|
||||
private readonly client: AzureMachineLearningWorkspacesContext;
|
||||
|
||||
constructor(client: AzureMachineLearningWorkspacesContext) {
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
listModels(resourceGroupName: string, workspaceName: string, options?: msRest.RequestOptionsBase): Promise<Models.ListWorkspaceModelsResult>;
|
||||
listModels(resourceGroupName: string, workspaceName: string, callback: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): void;
|
||||
listModels(resourceGroupName: string, workspaceName: string, options: msRest.RequestOptionsBase, callback: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): void;
|
||||
listModels(resourceGroupName: string, workspaceName: string, options?: msRest.RequestOptionsBase | msRest.ServiceCallback<Models.ListWorkspaceModelsResult>, callback?: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): Promise<Models.WorkspacesModelsResponse> {
|
||||
return this.client.sendOperationRequest(
|
||||
{
|
||||
resourceGroupName,
|
||||
workspaceName,
|
||||
options
|
||||
},
|
||||
listModelsOperationSpec,
|
||||
callback) as Promise<Models.WorkspacesModelsResponse>;
|
||||
}
|
||||
}
|
||||
|
||||
const serializer = new msRest.Serializer(Mappers);
|
||||
const listModelsOperationSpec: msRest.OperationSpec = {
|
||||
httpMethod: 'GET',
|
||||
path:
|
||||
'modelmanagement/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}/models',
|
||||
urlParameters: [
|
||||
Parameters.subscriptionId,
|
||||
Parameters.resourceGroupName,
|
||||
Parameters.workspaceName
|
||||
],
|
||||
queryParameters: [
|
||||
Parameters.apiVersion
|
||||
],
|
||||
headerParameters: [
|
||||
Parameters.acceptLanguage
|
||||
],
|
||||
responses: {
|
||||
200: {
|
||||
bodyMapper: Mappers.ListWorkspaceModelsResult
|
||||
},
|
||||
default: {
|
||||
bodyMapper: Mappers.MachineLearningServiceError
|
||||
}
|
||||
},
|
||||
serializer
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
import * as utils from '../common/utils';
|
||||
|
||||
export enum ScriptMode {
|
||||
Install = 'install',
|
||||
Uninstall = 'uninstall'
|
||||
}
|
||||
|
||||
export abstract class SqlPackageManageProviderBase {
|
||||
|
||||
/**
|
||||
* Base class for all SQL package managers
|
||||
*/
|
||||
constructor(protected _apiWrapper: ApiWrapper) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns database names
|
||||
*/
|
||||
public async getLocations(): Promise<nbExtensionApis.IPackageLocation[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let databases = await this._apiWrapper.listDatabases(connection.connectionId);
|
||||
return databases.map(x => {
|
||||
return { displayName: x, name: x };
|
||||
});
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
protected async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs given packages
|
||||
* @param packages Packages to install
|
||||
* @param useMinVersion minimum version
|
||||
*/
|
||||
public async installPackages(packages: nbExtensionApis.IPackageDetails[], useMinVersion: boolean, databaseName: string): Promise<void> {
|
||||
|
||||
if (packages) {
|
||||
await Promise.all(packages.map(x => this.installPackage(x, useMinVersion, databaseName)));
|
||||
}
|
||||
//TODO: use useMinVersion
|
||||
console.log(useMinVersion);
|
||||
}
|
||||
|
||||
private async installPackage(packageDetail: nbExtensionApis.IPackageDetails, useMinVersion: boolean, databaseName: string): Promise<void> {
|
||||
if (useMinVersion) {
|
||||
let packageOverview = await this.getPackageOverview(packageDetail.name);
|
||||
if (packageOverview && packageOverview.versions) {
|
||||
let minVersion = packageOverview.versions[packageOverview.versions.length - 1];
|
||||
packageDetail.version = minVersion;
|
||||
}
|
||||
}
|
||||
|
||||
await this.executeScripts(ScriptMode.Install, packageDetail, databaseName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Uninstalls given packages
|
||||
* @param packages Packages to uninstall
|
||||
*/
|
||||
public async uninstallPackages(packages: nbExtensionApis.IPackageDetails[], databaseName: string): Promise<void> {
|
||||
if (packages) {
|
||||
await Promise.all(packages.map(x => this.executeScripts(ScriptMode.Uninstall, x, databaseName)));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns package overview for given name
|
||||
* @param packageName Package Name
|
||||
*/
|
||||
public async getPackageOverview(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
|
||||
let packageOverview = await this.fetchPackage(packageName);
|
||||
if (packageOverview && packageOverview.versions) {
|
||||
packageOverview.versions = utils.sortPackageVersions(packageOverview.versions, false);
|
||||
}
|
||||
return packageOverview;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of packages
|
||||
*/
|
||||
public async listPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
let packages = await this.fetchPackages(databaseName);
|
||||
if (packages) {
|
||||
packages = packages.sort((a, b) => this.comparePackages(a, b));
|
||||
} else {
|
||||
packages = [];
|
||||
}
|
||||
return packages;
|
||||
}
|
||||
|
||||
private comparePackages(p1: nbExtensionApis.IPackageDetails, p2: nbExtensionApis.IPackageDetails): number {
|
||||
if (p1 && p2) {
|
||||
let compare = p1.name.localeCompare(p2.name);
|
||||
if (compare === 0) {
|
||||
compare = utils.comparePackageVersions(p1.version, p2.version);
|
||||
}
|
||||
return compare;
|
||||
}
|
||||
return p1 ? 1 : -1;
|
||||
}
|
||||
|
||||
protected abstract fetchPackage(packageName: string): Promise<nbExtensionApis.IPackageOverview>;
|
||||
protected abstract fetchPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]>;
|
||||
protected abstract executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void>;
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import * as constants from '../common/constants';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as utils from '../common/utils';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
|
||||
export class PackageManagementService {
|
||||
|
||||
/**
|
||||
* Creates a new instance of ServerConfigManager
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _queryRunner: QueryRunner,
|
||||
) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens server config documents
|
||||
*/
|
||||
public async openDocuments(): Promise<boolean> {
|
||||
return await this._apiWrapper.openExternal(vscode.Uri.parse(constants.mlsDocuments));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if mls is installed in the give SQL server instance
|
||||
*/
|
||||
public async isMachineLearningServiceEnabled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
|
||||
return this._queryRunner.isMachineLearningServiceEnabled(connection);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if R installed in the give SQL server instance
|
||||
*/
|
||||
public async isRInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
|
||||
return this._queryRunner.isRInstalled(connection);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if python installed in the give SQL server instance
|
||||
*/
|
||||
public async isPythonInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
|
||||
return this._queryRunner.isPythonInstalled(connection);
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates external script config
|
||||
* @param connection SQL Connection
|
||||
* @param enable if true external script will be enabled
|
||||
*/
|
||||
public async enableExternalScriptConfig(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
|
||||
let current = await this._queryRunner.isMachineLearningServiceEnabled(connection);
|
||||
|
||||
if (current) {
|
||||
return current;
|
||||
}
|
||||
let confirmed = await utils.promptConfirm(constants.confirmEnableExternalScripts, this._apiWrapper);
|
||||
if (confirmed) {
|
||||
await this._queryRunner.updateExternalScriptConfig(connection, true);
|
||||
current = await this._queryRunner.isMachineLearningServiceEnabled(connection);
|
||||
if (current) {
|
||||
this._apiWrapper.showInfoMessage(constants.mlsEnabledMessage);
|
||||
} else {
|
||||
this._apiWrapper.showErrorMessage(constants.mlsConfigUpdateFailed);
|
||||
}
|
||||
} else {
|
||||
this._apiWrapper.showErrorMessage(constants.externalScriptsIsRequiredError);
|
||||
}
|
||||
|
||||
return current;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns python packages installed in SQL server instance
|
||||
* @param connection SQL Connection
|
||||
*/
|
||||
public async getPythonPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
return this._queryRunner.getPythonPackages(connection, databaseName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns python packages installed in SQL server instance
|
||||
* @param connection SQL Connection
|
||||
*/
|
||||
public async getRPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
return this._queryRunner.getRPackages(connection, databaseName);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
import { SqlPythonPackageManageProvider } from './sqlPythonPackageManageProvider';
|
||||
import * as utils from '../common/utils';
|
||||
import * as constants from '../common/constants';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { ProcessService } from '../common/processService';
|
||||
import { Config } from '../configurations/config';
|
||||
import { isNullOrUndefined } from 'util';
|
||||
import { SqlRPackageManageProvider } from './sqlRPackageManageProvider';
|
||||
import { HttpClient } from '../common/httpClient';
|
||||
import { PackageConfigModel } from '../configurations/packageConfigModel';
|
||||
import { PackageManagementService } from './packageManagementService';
|
||||
|
||||
export class PackageManager {
|
||||
|
||||
private _sqlPythonPackagePackageManager: SqlPythonPackageManageProvider;
|
||||
private _sqlRPackageManager: SqlRPackageManageProvider;
|
||||
public dependenciesInstalled: boolean = false;
|
||||
|
||||
/**
|
||||
* Creates a new instance of PackageManager
|
||||
*/
|
||||
constructor(
|
||||
private _outputChannel: vscode.OutputChannel,
|
||||
private _rootFolder: string,
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _service: PackageManagementService,
|
||||
private _processService: ProcessService,
|
||||
private _config: Config,
|
||||
private _httpClient: HttpClient) {
|
||||
this._sqlPythonPackagePackageManager = new SqlPythonPackageManageProvider(this._outputChannel, this._apiWrapper, this._service, this._processService, this._config, this._httpClient);
|
||||
this._sqlRPackageManager = new SqlRPackageManageProvider(this._outputChannel, this._apiWrapper, this._service, this._processService, this._config, this._httpClient);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the instance and resister SQL package manager with manage package dialog
|
||||
*/
|
||||
public init(): void {
|
||||
}
|
||||
|
||||
private get pythonExecutable(): string {
|
||||
return this._config.pythonExecutable;
|
||||
}
|
||||
|
||||
private get _rExecutable(): string {
|
||||
return this._config.rExecutable;
|
||||
}
|
||||
/**
|
||||
* Returns packageManageProviders
|
||||
*/
|
||||
public get packageManageProviders(): nbExtensionApis.IPackageManageProvider[] {
|
||||
return [
|
||||
this._sqlPythonPackagePackageManager,
|
||||
this._sqlRPackageManager
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes manage package command for SQL server packages.
|
||||
*/
|
||||
public async managePackages(): Promise<void> {
|
||||
try {
|
||||
await this.enableExternalScript();
|
||||
|
||||
// Only execute the command if there's a valid connection with ml configuration enabled
|
||||
//
|
||||
let connection = await this.getCurrentConnection();
|
||||
let isPythonInstalled = await this._service.isPythonInstalled(connection);
|
||||
let isRInstalled = await this._service.isRInstalled(connection);
|
||||
let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined;
|
||||
if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) {
|
||||
defaultProvider = this._sqlPythonPackagePackageManager;
|
||||
} else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) {
|
||||
defaultProvider = this._sqlRPackageManager;
|
||||
}
|
||||
if (connection && defaultProvider) {
|
||||
|
||||
await this.enableExternalScript();
|
||||
// Install dependencies
|
||||
//
|
||||
if (!this.dependenciesInstalled) {
|
||||
await this.installDependencies();
|
||||
this.dependenciesInstalled = true;
|
||||
}
|
||||
|
||||
// Execute the command
|
||||
//
|
||||
this._apiWrapper.executeCommand(constants.managePackagesCommand, {
|
||||
defaultLocation: defaultProvider.packageTarget.location,
|
||||
defaultProviderId: defaultProvider.providerId
|
||||
});
|
||||
} else {
|
||||
this._apiWrapper.showInfoMessage(constants.managePackageCommandError);
|
||||
}
|
||||
} catch (err) {
|
||||
this._apiWrapper.showErrorMessage(err);
|
||||
}
|
||||
}
|
||||
|
||||
public async enableExternalScript(): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (!await this._service.enableExternalScriptConfig(connection)) {
|
||||
throw Error(constants.externalScriptsIsRequiredError);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [
|
||||
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
|
||||
this.installRequiredRPackages()], true);
|
||||
}
|
||||
|
||||
private async installRequiredRPackages(): Promise<void> {
|
||||
if (!this._config.rEnabled) {
|
||||
return;
|
||||
}
|
||||
if (!this._rExecutable) {
|
||||
throw new Error(constants.rConfigError);
|
||||
}
|
||||
|
||||
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
|
||||
await Promise.all(this._config.requiredSqlRPackages.map(x => this.installRPackage(x)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs required python packages
|
||||
*/
|
||||
public async installRequiredPythonPackages(requiredPackages: PackageConfigModel[]): Promise<void> {
|
||||
if (!this._config.pythonEnabled) {
|
||||
return;
|
||||
}
|
||||
if (!this.pythonExecutable) {
|
||||
throw new Error(constants.pythonConfigError);
|
||||
}
|
||||
if (!requiredPackages || requiredPackages.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
let installedPackages = await this.getInstalledPipPackages();
|
||||
let fileContent = '';
|
||||
requiredPackages.forEach(packageDetails => {
|
||||
let hasVersion = ('version' in packageDetails) && !isNullOrUndefined(packageDetails['version']) && packageDetails['version'].length > 0;
|
||||
if (!installedPackages.find(x => x.name === packageDetails['name']
|
||||
&& (!hasVersion || utils.comparePackageVersions(packageDetails['version'] || '', x.version) <= 0))) {
|
||||
let packageNameDetail = hasVersion ? `${packageDetails.name}==${packageDetails.version}` : `${packageDetails.name}`;
|
||||
fileContent = `${fileContent}${packageNameDetail}\n`;
|
||||
}
|
||||
});
|
||||
|
||||
if (fileContent) {
|
||||
let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages(fileContent), this._apiWrapper);
|
||||
if (confirmed) {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackages);
|
||||
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
|
||||
return await this.installPipPackage(tempFilePath);
|
||||
});
|
||||
this._outputChannel.appendLine(result);
|
||||
|
||||
} else {
|
||||
throw Error(constants.requiredPackagesNotInstalled);
|
||||
}
|
||||
} else {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackagesAlreadyInstalled);
|
||||
}
|
||||
}
|
||||
|
||||
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
try {
|
||||
let cmd = `"${this.pythonExecutable}" -m pip list --format=json`;
|
||||
let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined);
|
||||
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
|
||||
if (packagesInfo && packagesInfo.indexOf(']') > 0) {
|
||||
packagesResult = <nbExtensionApis.IPackageDetails[]>JSON.parse(packagesInfo.substr(0, packagesInfo.indexOf(']') + 1));
|
||||
}
|
||||
return packagesResult;
|
||||
}
|
||||
catch (err) {
|
||||
this._outputChannel.appendLine(constants.installDependenciesGetPackagesError(err ? err.message : ''));
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private async installPipPackage(requirementFilePath: string): Promise<string> {
|
||||
let cmd = `"${this.pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
|
||||
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
}
|
||||
|
||||
private async installRPackage(model: PackageConfigModel): Promise<string> {
|
||||
let output = '';
|
||||
let cmd = '';
|
||||
if (model.downloadUrl) {
|
||||
const packageFile = utils.getPackageFilePath(this._rootFolder, model.fileName || model.name);
|
||||
const packageExist = await utils.exists(packageFile);
|
||||
if (!packageExist) {
|
||||
await this._httpClient.download(model.downloadUrl, packageFile, this._outputChannel);
|
||||
}
|
||||
cmd = `"${this._rExecutable}" CMD INSTALL ${packageFile}`;
|
||||
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
} else if (model.repository) {
|
||||
cmd = `"${this._rExecutable}" -e "install.packages('${model.name}', repos='${model.repository}')"`;
|
||||
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { ProcessService } from '../common/processService';
|
||||
import { Config } from '../configurations/config';
|
||||
import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProviderBase';
|
||||
import { HttpClient } from '../common/httpClient';
|
||||
import * as utils from '../common/utils';
|
||||
import { PackageManagementService } from './packageManagementService';
|
||||
|
||||
/**
|
||||
* Manage Package Provider for python packages inside SQL server databases
|
||||
*/
|
||||
export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase implements nbExtensionApis.IPackageManageProvider {
|
||||
public static ProviderId = 'sql_Python';
|
||||
|
||||
/**
|
||||
* Creates new a instance
|
||||
*/
|
||||
constructor(
|
||||
private _outputChannel: vscode.OutputChannel,
|
||||
apiWrapper: ApiWrapper,
|
||||
private _service: PackageManagementService,
|
||||
private _processService: ProcessService,
|
||||
private _config: Config,
|
||||
private _httpClient: HttpClient) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns provider Id
|
||||
*/
|
||||
public get providerId(): string {
|
||||
return SqlPythonPackageManageProvider.ProviderId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns package target
|
||||
*/
|
||||
public get packageTarget(): nbExtensionApis.IPackageTarget {
|
||||
return { location: 'SQL', packageType: 'Python' };
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of packages
|
||||
*/
|
||||
protected async fetchPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
return await this._service.getPythonPackages(await this.getCurrentConnection(), databaseName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a script to install or uninstall a python package inside current SQL Server connection
|
||||
* @param packageDetails Packages to install or uninstall
|
||||
* @param scriptMode can be 'install' or 'uninstall'
|
||||
*/
|
||||
protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
|
||||
if (connection) {
|
||||
let port = '1433';
|
||||
let server = connection.serverName;
|
||||
let database = databaseName ? `, database="${databaseName}"` : '';
|
||||
let index = connection.serverName.indexOf(',');
|
||||
if (index > 0) {
|
||||
port = connection.serverName.substring(index + 1);
|
||||
server = connection.serverName.substring(0, index);
|
||||
}
|
||||
|
||||
let pythonConnectionParts = `server="${server}", port=${port}, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"${database})`;
|
||||
let pythonCommandScript = scriptMode === ScriptMode.Install ?
|
||||
`pkgmanager.install(package="${packageDetails.name}", version="${packageDetails.version}")` :
|
||||
`pkgmanager.uninstall(package_name="${packageDetails.name}")`;
|
||||
|
||||
let scripts: string[] = [
|
||||
'import sqlmlutils',
|
||||
`connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 17 for SQL Server", ${pythonConnectionParts}`,
|
||||
'pkgmanager = sqlmlutils.SQLPackageManager(connection)',
|
||||
pythonCommandScript
|
||||
];
|
||||
let pythonExecutable = this._config.pythonExecutable;
|
||||
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the provider can be used
|
||||
*/
|
||||
async canUseProvider(): Promise<boolean> {
|
||||
if (!this._config.pythonEnabled) {
|
||||
return false;
|
||||
}
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection && await this._service.isPythonInstalled(connection)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private getPackageLink(packageName: string): string {
|
||||
return `https://pypi.org/pypi/${packageName}/json`;
|
||||
}
|
||||
|
||||
protected async fetchPackage(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
|
||||
let body = await this._httpClient.fetch(this.getPackageLink(packageName));
|
||||
let packagesJson = JSON.parse(body);
|
||||
let versionNums: string[] = [];
|
||||
let packageSummary = '';
|
||||
if (packagesJson) {
|
||||
if (packagesJson.releases) {
|
||||
let versionKeys = Object.keys(packagesJson.releases);
|
||||
versionKeys = versionKeys.filter(versionKey => {
|
||||
let releaseInfo = packagesJson.releases[versionKey];
|
||||
return Array.isArray(releaseInfo) && releaseInfo.length > 0;
|
||||
});
|
||||
versionNums = utils.sortPackageVersions(versionKeys, false);
|
||||
}
|
||||
|
||||
if (packagesJson.info && packagesJson.info.summary) {
|
||||
packageSummary = packagesJson.info.summary;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
name: packageName,
|
||||
versions: versionNums,
|
||||
summary: packageSummary
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { ProcessService } from '../common/processService';
|
||||
import { Config } from '../configurations/config';
|
||||
import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProviderBase';
|
||||
import { HttpClient } from '../common/httpClient';
|
||||
import * as constants from '../common/constants';
|
||||
import { PackageManagementService } from './packageManagementService';
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Manage Package Provider for r packages inside SQL server databases
|
||||
*/
|
||||
export class SqlRPackageManageProvider extends SqlPackageManageProviderBase implements nbExtensionApis.IPackageManageProvider {
|
||||
|
||||
public static ProviderId = 'sql_R';
|
||||
|
||||
/**
|
||||
* Creates new a instance
|
||||
*/
|
||||
constructor(
|
||||
private _outputChannel: vscode.OutputChannel,
|
||||
apiWrapper: ApiWrapper,
|
||||
private _service: PackageManagementService,
|
||||
private _processService: ProcessService,
|
||||
private _config: Config,
|
||||
private _httpClient: HttpClient) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns provider Id
|
||||
*/
|
||||
public get providerId(): string {
|
||||
return SqlRPackageManageProvider.ProviderId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns package target
|
||||
*/
|
||||
public get packageTarget(): nbExtensionApis.IPackageTarget {
|
||||
return { location: 'SQL', packageType: 'R' };
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of packages
|
||||
*/
|
||||
protected async fetchPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
return await this._service.getRPackages(await this.getCurrentConnection(), databaseName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a script to install or uninstall a r package inside current SQL Server connection
|
||||
* @param packageDetails Packages to install or uninstall
|
||||
* @param scriptMode can be 'install' or 'uninstall'
|
||||
*/
|
||||
protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
|
||||
if (connection) {
|
||||
let database = databaseName ? `, database="${databaseName}"` : '';
|
||||
let connectionParts = `server="${connection.serverName}", uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"${database}`;
|
||||
let rCommandScript = scriptMode === ScriptMode.Install ? 'sql_install.packages' : 'sql_remove.packages';
|
||||
|
||||
let scripts: string[] = [
|
||||
'formals(quit)$save <- formals(q)$save <- "no"',
|
||||
'library(sqlmlutils)',
|
||||
`connection <- connectionInfo(${connectionParts})`,
|
||||
`r = getOption("repos")`,
|
||||
`r["CRAN"] = "${this._config.rPackagesRepository}"`,
|
||||
`options(repos = r)`,
|
||||
`pkgs <- c("${packageDetails.name}")`,
|
||||
`${rCommandScript}(connectionString = connection, pkgs, scope = "PUBLIC")`,
|
||||
'q()'
|
||||
];
|
||||
let rExecutable = this._config.rExecutable;
|
||||
await this._processService.execScripts(`${rExecutable}`, scripts, ['--vanilla'], this._outputChannel);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the provider can be used
|
||||
*/
|
||||
async canUseProvider(): Promise<boolean> {
|
||||
if (!this._config.rEnabled) {
|
||||
return false;
|
||||
}
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection && await this._service.isRInstalled(connection)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private getPackageLink(packageName: string): string {
|
||||
return `${this._config.rPackagesRepository}/web/packages/${packageName}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns package overview for given name
|
||||
* @param packageName Package Name
|
||||
*/
|
||||
protected async fetchPackage(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
|
||||
let packagePreview: nbExtensionApis.IPackageOverview = {
|
||||
name: packageName,
|
||||
versions: [constants.latestVersion],
|
||||
summary: ''
|
||||
};
|
||||
|
||||
await this._httpClient.fetch(this.getPackageLink(packageName));
|
||||
return packagePreview;
|
||||
}
|
||||
}
|
||||
27
extensions/machine-learning/src/prediction/interfaces.ts
Normal file
27
extensions/machine-learning/src/prediction/interfaces.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
export interface TableColumn {
|
||||
columnName: string;
|
||||
dataType?: string;
|
||||
}
|
||||
|
||||
export interface PredictColumn extends TableColumn {
|
||||
paramName?: string;
|
||||
}
|
||||
|
||||
export interface DatabaseTable {
|
||||
databaseName: string | undefined;
|
||||
tableName: string | undefined;
|
||||
schema: string | undefined
|
||||
}
|
||||
|
||||
export interface PredictInputParameters extends DatabaseTable {
|
||||
inputColumns: PredictColumn[] | undefined
|
||||
}
|
||||
|
||||
export interface PredictParameters extends PredictInputParameters {
|
||||
outputColumns: PredictColumn[] | undefined
|
||||
}
|
||||
217
extensions/machine-learning/src/prediction/predictService.ts
Normal file
217
extensions/machine-learning/src/prediction/predictService.ts
Normal file
@@ -0,0 +1,217 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import * as utils from '../common/utils';
|
||||
import { ImportedModel } from '../modelManagement/interfaces';
|
||||
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* Service to make prediction
|
||||
*/
|
||||
export class PredictService {
|
||||
|
||||
/**
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _queryRunner: QueryRunner) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the list of databases
|
||||
*/
|
||||
public async getDatabaseList(): Promise<string[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
return await this._apiWrapper.listDatabases(connection.connectionId);
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates prediction script given model info and predict parameters
|
||||
* @param predictParams predict parameters
|
||||
* @param registeredModel model parameters
|
||||
*/
|
||||
public async generatePredictScript(
|
||||
predictParams: PredictParameters,
|
||||
registeredModel: ImportedModel | undefined,
|
||||
filePath: string | undefined
|
||||
): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let query = '';
|
||||
if (registeredModel && registeredModel.id) {
|
||||
query = this.getPredictScriptWithModelId(
|
||||
registeredModel.id,
|
||||
predictParams.inputColumns || [],
|
||||
predictParams.outputColumns || [],
|
||||
predictParams,
|
||||
registeredModel.table);
|
||||
} else if (filePath) {
|
||||
let modelBytes = await utils.readFileInHex(filePath || '');
|
||||
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
|
||||
predictParams.outputColumns || [],
|
||||
predictParams);
|
||||
}
|
||||
let document = await this._apiWrapper.openTextDocument({
|
||||
language: 'sql',
|
||||
content: query
|
||||
});
|
||||
await this._apiWrapper.showTextDocument(document.uri);
|
||||
await this._apiWrapper.connect(document.uri.toString(), connection.connectionId);
|
||||
this._apiWrapper.runQuery(document.uri.toString(), undefined, false);
|
||||
return query;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of tables given database name
|
||||
* @param databaseName database name
|
||||
*/
|
||||
public async getTableList(databaseName: string): Promise<DatabaseTable[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: DatabaseTable[] = [];
|
||||
if (connection) {
|
||||
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName));
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push({
|
||||
databaseName: databaseName,
|
||||
tableName: row[0].displayValue,
|
||||
schema: row[1].displayValue
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
/**
|
||||
*Returns list of column names of a database
|
||||
* @param databaseTable table info
|
||||
*/
|
||||
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<TableColumn[]> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let list: TableColumn[] = [];
|
||||
if (connection && databaseTable.databaseName) {
|
||||
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
|
||||
let result = await this._queryRunner.safeRunQuery(connection, query);
|
||||
if (result && result.rows && result.rows.length > 0) {
|
||||
result.rows.forEach(row => {
|
||||
list.push({
|
||||
columnName: row[0].displayValue,
|
||||
dataType: row[1].displayValue
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private getTableColumnsScript(databaseTable: DatabaseTable): string {
|
||||
return `
|
||||
SELECT COLUMN_NAME,DATA_TYPE
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
|
||||
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
|
||||
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
|
||||
`;
|
||||
}
|
||||
|
||||
private getTablesScript(databaseName: string): string {
|
||||
return `
|
||||
SELECT TABLE_NAME,TABLE_SCHEMA
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
|
||||
`;
|
||||
}
|
||||
|
||||
private getPredictScriptWithModelId(
|
||||
modelId: number,
|
||||
columns: PredictColumn[],
|
||||
outputColumns: PredictColumn[],
|
||||
sourceTable: DatabaseTable,
|
||||
importTable: DatabaseTable): string {
|
||||
const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || '');
|
||||
return `
|
||||
DECLARE @model VARBINARY(max) = (
|
||||
SELECT model
|
||||
FROM ${threePartTableName}
|
||||
WHERE model_id = ${modelId}
|
||||
);
|
||||
WITH predict_input
|
||||
AS (
|
||||
SELECT TOP 1000
|
||||
${this.getInputColumnNames(columns, 'pi')}
|
||||
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
|
||||
)
|
||||
SELECT
|
||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
|
||||
FROM PREDICT(MODEL = @model, DATA = predict_input)
|
||||
WITH (
|
||||
${this.getOutputParameters(outputColumns)}
|
||||
) AS p
|
||||
`;
|
||||
}
|
||||
|
||||
private getPredictScriptWithModelBytes(
|
||||
modelBytes: string,
|
||||
columns: PredictColumn[],
|
||||
outputColumns: PredictColumn[],
|
||||
databaseNameTable: DatabaseTable): string {
|
||||
return `
|
||||
WITH predict_input
|
||||
AS (
|
||||
SELECT TOP 1000
|
||||
${this.getInputColumnNames(columns, 'pi')}
|
||||
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
|
||||
)
|
||||
SELECT
|
||||
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')}
|
||||
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
|
||||
WITH (
|
||||
${this.getOutputParameters(outputColumns)}
|
||||
) AS p
|
||||
`;
|
||||
}
|
||||
|
||||
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return this.getColumnName(tableName, c.paramName || '', c.columnName);
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getOutputColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return this.getColumnName(tableName, c.columnName, c.paramName || '');
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getColumnName(tableName: string, columnName: string, displayName: string) {
|
||||
return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`;
|
||||
}
|
||||
|
||||
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
|
||||
return columns.map(c => {
|
||||
return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
|
||||
private getOutputParameters(columns: PredictColumn[]) {
|
||||
return columns.map(c => {
|
||||
return `${c.paramName} ${c.dataType}`;
|
||||
}).join(',\n');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import { ProcessService } from '../../common/processService';
|
||||
import * as utils from '../../common/utils';
|
||||
import should = require('should');
|
||||
|
||||
interface TestContext {
|
||||
|
||||
outputChannel: vscode.OutputChannel;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
return {
|
||||
outputChannel: {
|
||||
name: '',
|
||||
append: () => { },
|
||||
appendLine: () => { },
|
||||
clear: () => { },
|
||||
show: () => { },
|
||||
hide: () => { },
|
||||
dispose: () => { }
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function execFolderListCommand(context: TestContext, service : ProcessService): Promise<string> {
|
||||
if (utils.isWindows()) {
|
||||
return service.execScripts('cmd', ['dir', '.'], [], context.outputChannel);
|
||||
} else {
|
||||
return service.execScripts('/bin/sh', ['-c', 'ls'], [], context.outputChannel);
|
||||
}
|
||||
}
|
||||
|
||||
function execFolderListBufferedCommand(context: TestContext, service : ProcessService): Promise<string> {
|
||||
if (utils.isWindows()) {
|
||||
return service.executeBufferedCommand('dir', context.outputChannel);
|
||||
} else {
|
||||
return service.executeBufferedCommand('ls', context.outputChannel);
|
||||
}
|
||||
}
|
||||
|
||||
describe('Process Service', () => {
|
||||
it('Executing a valid script should return successfully', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
let service = new ProcessService();
|
||||
await should(execFolderListCommand(context, service)).resolved();
|
||||
});
|
||||
|
||||
it('execFolderListCommand should reject if command time out @UNSTABLE@', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
let service = new ProcessService();
|
||||
service.timeout = 10;
|
||||
await should(execFolderListCommand(context, service)).rejected();
|
||||
});
|
||||
|
||||
it('executeBufferedCommand should resolve give valid script', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
let service = new ProcessService();
|
||||
service.timeout = 2000;
|
||||
await should(execFolderListBufferedCommand(context, service)).resolved();
|
||||
});
|
||||
|
||||
});
|
||||
48
extensions/machine-learning/src/test/index.ts
Normal file
48
extensions/machine-learning/src/test/index.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as path from 'path';
|
||||
const testRunner = require('vscodetestcover');
|
||||
|
||||
const suite = 'machine learning Extension Tests';
|
||||
|
||||
const mochaOptions: any = {
|
||||
ui: 'bdd',
|
||||
useColors: true,
|
||||
timeout: 10000
|
||||
};
|
||||
|
||||
// set relevant mocha options from the environment
|
||||
if (process.env.ADS_TEST_GREP) {
|
||||
mochaOptions.grep = process.env.ADS_TEST_GREP;
|
||||
console.log(`setting options.grep to: ${mochaOptions.grep}`);
|
||||
}
|
||||
if (process.env.ADS_TEST_INVERT_GREP) {
|
||||
mochaOptions.invert = parseInt(process.env.ADS_TEST_INVERT_GREP);
|
||||
console.log(`setting options.invert to: ${mochaOptions.invert}`);
|
||||
}
|
||||
if (process.env.ADS_TEST_TIMEOUT) {
|
||||
mochaOptions.timeout = parseInt(process.env.ADS_TEST_TIMEOUT);
|
||||
console.log(`setting options.timeout to: ${mochaOptions.timeout}`);
|
||||
}
|
||||
if (process.env.ADS_TEST_RETRIES) {
|
||||
mochaOptions.retries = parseInt(process.env.ADS_TEST_RETRIES);
|
||||
console.log(`setting options.retries to: ${mochaOptions.retries}`);
|
||||
}
|
||||
|
||||
if (process.env.BUILD_ARTIFACTSTAGINGDIRECTORY) {
|
||||
mochaOptions.reporter = 'mocha-multi-reporters';
|
||||
mochaOptions.reporterOptions = {
|
||||
reporterEnabled: 'spec, mocha-junit-reporter',
|
||||
mochaJunitReporterReporterOptions: {
|
||||
testsuitesTitle: `${suite} ${process.platform}`,
|
||||
mochaFile: path.join(process.env.BUILD_ARTIFACTSTAGINGDIRECTORY, `test-results/${process.platform}-${suite.toLowerCase().replace(/[^\w]/g, '-')}-results.xml`)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
testRunner.configure(mochaOptions, { coverConfig: '../../coverConfig.json' });
|
||||
|
||||
export = testRunner;
|
||||
150
extensions/machine-learning/src/test/mainController.test.ts
Normal file
150
extensions/machine-learning/src/test/mainController.test.ts
Normal file
@@ -0,0 +1,150 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as path from 'path';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { ProcessService } from '../common/processService';
|
||||
import MainController from '../controllers/mainController';
|
||||
import { PackageManager } from '../packageManagement/packageManager';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
|
||||
interface TestContext {
|
||||
notebookExtension: vscode.Extension<any>;
|
||||
jupyterInstallation: nbExtensionApis.IJupyterServerInstallation;
|
||||
jupyterController: nbExtensionApis.IJupyterController;
|
||||
nbExtensionApis: nbExtensionApis.IExtensionApi;
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
processService: TypeMoq.IMock<ProcessService>;
|
||||
context: vscode.ExtensionContext;
|
||||
outputChannel: vscode.OutputChannel;
|
||||
extension: vscode.Extension<any>;
|
||||
packageManager: TypeMoq.IMock<PackageManager>;
|
||||
workspaceConfig: vscode.WorkspaceConfiguration;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
let packages = new Map<string, nbExtensionApis.IPackageManageProvider>();
|
||||
let jupyterInstallation: nbExtensionApis.IJupyterServerInstallation = {
|
||||
installCondaPackages: () => { return Promise.resolve(); },
|
||||
getInstalledPipPackages: () => { return Promise.resolve([]); },
|
||||
installPipPackages: () => { return Promise.resolve(); },
|
||||
uninstallPipPackages: () => { return Promise.resolve(); },
|
||||
uninstallCondaPackages: () => { return Promise.resolve(); },
|
||||
executeBufferedCommand: () => { return Promise.resolve(''); },
|
||||
executeStreamedCommand: () => { return Promise.resolve(); },
|
||||
pythonExecutable: '',
|
||||
pythonInstallationPath: '',
|
||||
installPythonPackage: () => { return Promise.resolve(); }
|
||||
};
|
||||
|
||||
let jupyterController = {
|
||||
jupyterInstallation: jupyterInstallation
|
||||
};
|
||||
|
||||
let extensionPath = path.join(__dirname, '..', '..');
|
||||
let extensionApi: nbExtensionApis.IExtensionApi = {
|
||||
getJupyterController: () => { return jupyterController; },
|
||||
registerPackageManager: (providerId: string, packageManagerProvider: nbExtensionApis.IPackageManageProvider) => {
|
||||
packages.set(providerId, packageManagerProvider);
|
||||
},
|
||||
getPackageManagers: () => { return packages; },
|
||||
};
|
||||
return {
|
||||
jupyterInstallation: jupyterInstallation,
|
||||
jupyterController: jupyterController,
|
||||
nbExtensionApis: extensionApi,
|
||||
notebookExtension: {
|
||||
id: '',
|
||||
extensionPath: '',
|
||||
isActive: true,
|
||||
packageJSON: '',
|
||||
extensionKind: vscode.ExtensionKind.UI,
|
||||
exports: extensionApi,
|
||||
activate: () => {return Promise.resolve();},
|
||||
extensionUri: vscode.Uri.parse('')
|
||||
},
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
processService: TypeMoq.Mock.ofType(ProcessService),
|
||||
packageManager: TypeMoq.Mock.ofType(PackageManager),
|
||||
context: {
|
||||
subscriptions: [],
|
||||
workspaceState: {
|
||||
get: () => {return undefined;},
|
||||
update: () => {return Promise.resolve();}
|
||||
},
|
||||
globalState: {
|
||||
get: () => {return Promise.resolve();},
|
||||
update: () => {return Promise.resolve();}
|
||||
},
|
||||
extensionPath: extensionPath,
|
||||
asAbsolutePath: () => {return '';},
|
||||
storagePath: '',
|
||||
globalStoragePath: '',
|
||||
logPath: '',
|
||||
extensionUri: vscode.Uri.parse('')
|
||||
},
|
||||
outputChannel: {
|
||||
name: '',
|
||||
append: () => { },
|
||||
appendLine: () => { },
|
||||
clear: () => { },
|
||||
show: () => { },
|
||||
hide: () => { },
|
||||
dispose: () => { }
|
||||
},
|
||||
extension: {
|
||||
id: '',
|
||||
extensionPath: '',
|
||||
isActive: true,
|
||||
packageJSON: {},
|
||||
extensionKind: vscode.ExtensionKind.UI,
|
||||
exports: {},
|
||||
activate: () => { return Promise.resolve(); },
|
||||
extensionUri: vscode.Uri.parse('')
|
||||
},
|
||||
workspaceConfig: {
|
||||
get: () => {return 'value';},
|
||||
has: () => {return true;},
|
||||
inspect: () => {return undefined;},
|
||||
update: () => {return Promise.reject();},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function createController(testContext: TestContext): MainController {
|
||||
let controller = new MainController(testContext.context, testContext.apiWrapper.object, testContext.queryRunner.object, testContext.processService.object, testContext.packageManager.object);
|
||||
return controller;
|
||||
}
|
||||
|
||||
describe('Main Controller', () => {
|
||||
it('Should create new instance successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
testContext.apiWrapper.setup(x => x.createOutputChannel(TypeMoq.It.isAny())).returns(() => testContext.outputChannel);
|
||||
should.doesNotThrow(() => createController(testContext));
|
||||
});
|
||||
|
||||
it('initialize Should install dependencies successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
testContext.apiWrapper.setup(x => x.getExtension(TypeMoq.It.isAny())).returns(() => testContext.notebookExtension);
|
||||
testContext.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => testContext.workspaceConfig);
|
||||
testContext.apiWrapper.setup(x => x.createOutputChannel(TypeMoq.It.isAny())).returns(() => testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.getExtension(TypeMoq.It.isAny())).returns(() => testContext.extension);
|
||||
testContext.packageManager.setup(x => x.managePackages()).returns(() => Promise.resolve());
|
||||
testContext.packageManager.setup(x => x.installDependencies()).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.registerCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny()));
|
||||
let controller = createController(testContext);
|
||||
await controller.activate();
|
||||
|
||||
should.notEqual(controller.config.requiredSqlPythonPackages.find(x => x.name ==='sqlmlutils'), undefined);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,232 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { HttpClient } from '../../common/httpClient';
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
|
||||
import * as utils from '../utils';
|
||||
import { Workspace, WorkspacesListByResourceGroupResponse } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { WorkspaceModel, AssetsQueryByIdResponse, Asset, GetArtifactContentInformation2Response } from '../../modelManagement/interfaces';
|
||||
import { AzureMachineLearningWorkspaces, Workspaces } from '@azure/arm-machinelearningservices';
|
||||
import { WorkspaceModels } from '../../modelManagement/workspacesModels';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
httpClient: TypeMoq.IMock<HttpClient>;
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
accounts: azdata.Account[];
|
||||
subscriptions: azureResource.AzureResourceSubscription[];
|
||||
groups: azureResource.AzureResourceResourceGroup[];
|
||||
workspaces: Workspace[];
|
||||
models: WorkspaceModel[];
|
||||
client: TypeMoq.IMock<AzureMachineLearningWorkspaces>;
|
||||
workspacesClient: TypeMoq.IMock<Workspaces>;
|
||||
modelClient: TypeMoq.IMock<WorkspaceModels>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
const context = utils.createContext();
|
||||
const workspaces = TypeMoq.Mock.ofType(Workspaces);
|
||||
const credentials = {
|
||||
signRequest: () => {
|
||||
return Promise.resolve(undefined!!);
|
||||
}
|
||||
};
|
||||
const client = TypeMoq.Mock.ofInstance(new AzureMachineLearningWorkspaces(credentials, 'subscription'));
|
||||
client.setup(x => x.apiVersion).returns(() => '20180101');
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
httpClient: TypeMoq.Mock.ofType(HttpClient),
|
||||
outputChannel: context.outputChannel,
|
||||
op: context.op,
|
||||
accounts: [
|
||||
{
|
||||
key: {
|
||||
providerId: '',
|
||||
accountId: 'a1'
|
||||
},
|
||||
displayInfo: {
|
||||
contextualDisplayName: '',
|
||||
accountType: '',
|
||||
displayName: 'a1',
|
||||
userId: 'a1'
|
||||
},
|
||||
properties:
|
||||
{
|
||||
tenants: [
|
||||
{
|
||||
id: '1',
|
||||
}
|
||||
]
|
||||
}
|
||||
,
|
||||
isStale: true
|
||||
}
|
||||
],
|
||||
subscriptions: [
|
||||
{
|
||||
name: 's1',
|
||||
id: 's1'
|
||||
}
|
||||
],
|
||||
groups: [
|
||||
{
|
||||
name: 'g1',
|
||||
id: 'g1'
|
||||
}
|
||||
],
|
||||
workspaces: [{
|
||||
name: 'w1',
|
||||
id: 'w1'
|
||||
}
|
||||
],
|
||||
models: [
|
||||
{
|
||||
name: 'm1',
|
||||
id: 'm1',
|
||||
url: 'aml://asset/test.test'
|
||||
}
|
||||
],
|
||||
client: client,
|
||||
workspacesClient: workspaces,
|
||||
modelClient: TypeMoq.Mock.ofInstance(new WorkspaceModels(client.object))
|
||||
};
|
||||
}
|
||||
|
||||
describe('AzureModelRegistryService', () => {
|
||||
it('getAccounts should return the list of accounts successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const accounts = testContext.accounts;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.getAllAccounts()).returns(() => Promise.resolve(accounts));
|
||||
let actual = await service.getAccounts();
|
||||
should.deepEqual(actual, testContext.accounts);
|
||||
});
|
||||
|
||||
it('getSubscriptions should return the list of subscriptions successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const expected = testContext.subscriptions;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ subscriptions: expected, errors: [] }));
|
||||
let actual = await service.getSubscriptions(testContext.accounts[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getGroups should return the list of groups successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const expected = testContext.groups;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ resourceGroups: expected, errors: [] }));
|
||||
let actual = await service.getGroups(testContext.accounts[0], testContext.subscriptions[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getWorkspaces should return the list of workspaces successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const response: WorkspacesListByResourceGroupResponse = Object.assign(new Array<Workspace>(...testContext.workspaces), {
|
||||
_response: undefined!
|
||||
});
|
||||
const expected = testContext.workspaces;
|
||||
testContext.workspacesClient.setup(x => x.listByResourceGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(response));
|
||||
testContext.workspacesClient.setup(x => x.listBySubscription()).returns(() => Promise.resolve(response));
|
||||
testContext.client.setup(x => x.workspaces).returns(() => testContext.workspacesClient.object);
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
|
||||
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
let actual = await service.getWorkspaces(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getModels should return the list of models successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
|
||||
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
|
||||
const expected = testContext.models;
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
service.ModelClient = testContext.modelClient.object;
|
||||
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
|
||||
let actual = await service.getModels(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0]);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('downloadModel should download model artifact successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const asset: Asset =
|
||||
{
|
||||
id: '1',
|
||||
name: 'asset',
|
||||
artifacts: [
|
||||
{
|
||||
id: '/1/2/3/4/5/'
|
||||
}
|
||||
]
|
||||
};
|
||||
const assetResponse: AssetsQueryByIdResponse = Object.assign(asset, {
|
||||
_response: undefined!
|
||||
});
|
||||
const artifactResponse: GetArtifactContentInformation2Response = Object.assign({
|
||||
contentUri: 'downloadUrl'
|
||||
}, {
|
||||
_response: undefined!
|
||||
});
|
||||
|
||||
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
|
||||
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
|
||||
testContext.config.setup(x => x.amlExperienceUrl).returns(() => 'test.url');
|
||||
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
|
||||
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('modelmanagement')), TypeMoq.It.isAny())).returns(() => Promise.resolve(assetResponse));
|
||||
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
|
||||
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('artifact')), TypeMoq.It.isAny())).returns(() => Promise.resolve(artifactResponse));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
let service = new AzureModelRegistryService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object,
|
||||
testContext.outputChannel);
|
||||
service.AzureMachineLearningClient = testContext.client.object;
|
||||
service.ModelClient = testContext.modelClient.object;
|
||||
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
|
||||
let actual = await service.downloadModel(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0], testContext.models[0]);
|
||||
should.notEqual(actual, undefined);
|
||||
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,453 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as utils from '../../common/utils';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { ImportedModel } from '../../modelManagement/interfaces';
|
||||
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as fs from 'fs';
|
||||
import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent';
|
||||
import { DatabaseTable } from '../../prediction/interfaces';
|
||||
import * as queries from '../../modelManagement/queries';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
modelClient: TypeMoq.IMock<ModelPythonClient>;
|
||||
recentModels: TypeMoq.IMock<ModelConfigRecent>;
|
||||
importTable: DatabaseTable;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
modelClient: TypeMoq.Mock.ofType(ModelPythonClient),
|
||||
recentModels: TypeMoq.Mock.ofType(ModelConfigRecent),
|
||||
importTable: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
describe('DeployedModelService', () => {
|
||||
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
let importTable: DatabaseTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
};
|
||||
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
await should(service.getDeployedModels(importTable)).rejected();
|
||||
});
|
||||
|
||||
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const expected: ImportedModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
deploymentTime: '2018-01-01',
|
||||
version: '1.1',
|
||||
framework: 'onnx',
|
||||
frameworkVersion: '1',
|
||||
deployedBy: '1',
|
||||
runId: 'run1',
|
||||
table: testContext.importTable
|
||||
|
||||
}
|
||||
];
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [
|
||||
[
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'name1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'desc1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1.1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'onnx',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'run1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]
|
||||
]
|
||||
};
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
const actual = await service.getDeployedModels(testContext.importTable);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('loadModelParameters should load parameters using python client successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
},
|
||||
]
|
||||
};
|
||||
testContext.modelClient.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(expected));
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
const actual = await service.loadModelParameters('');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('downloadModel should download model successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const model: ImportedModel =
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
deploymentTime: '2018-01-01',
|
||||
version: '1.1',
|
||||
framework: 'onnx',
|
||||
frameworkVersion: '1',
|
||||
deployedBy: '1',
|
||||
runId: 'run1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [
|
||||
[
|
||||
{
|
||||
displayValue: await utils.readFileInHex(tempFilePath),
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]
|
||||
]
|
||||
};
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
const actual = await service.downloadModel(model);
|
||||
should.notEqual(actual, undefined);
|
||||
});
|
||||
|
||||
it('deployLocalModel should returns models successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const model: ImportedModel =
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
deploymentTime: '2018-01-01',
|
||||
version: '1.1',
|
||||
framework: 'onnx',
|
||||
frameworkVersion: '1',
|
||||
deployedBy: '1',
|
||||
runId: 'run1',
|
||||
table: testContext.importTable
|
||||
|
||||
};
|
||||
const row = [
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'name1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'desc1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1.1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'onnx',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '2018-01-01',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'run1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
];
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [row]
|
||||
};
|
||||
let updatedResult = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [row, row]
|
||||
};
|
||||
let deployed = false;
|
||||
let service = new DeployedModelService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.config.object,
|
||||
testContext.queryRunner.object,
|
||||
testContext.modelClient.object,
|
||||
testContext.recentModels.object);
|
||||
|
||||
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('INSERT INTO') > 0), TypeMoq.It.isAny())).returns(() => {
|
||||
deployed = true;
|
||||
return Promise.resolve(result);
|
||||
});
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
|
||||
});
|
||||
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
|
||||
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
|
||||
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
|
||||
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
|
||||
let tempFilePath: string = '';
|
||||
try {
|
||||
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
await should(service.deployLocalModel(tempFilePath, model, testContext.importTable)).resolved();
|
||||
}
|
||||
finally {
|
||||
await utils.deleteFile(tempFilePath);
|
||||
}
|
||||
});
|
||||
|
||||
it('getConfigureQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
|
||||
testContext.importTable.databaseName = 'd[]b';
|
||||
testContext.importTable.tableName = 'ta[b]le';
|
||||
testContext.importTable.schema = 'dbo';
|
||||
const expected = `
|
||||
IF NOT EXISTS
|
||||
( SELECT t.name, s.name
|
||||
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
|
||||
WHERE t.name = 'ta[b]le'
|
||||
AND s.name = 'dbo'
|
||||
)
|
||||
BEGIN
|
||||
CREATE TABLE [dbo].[ta[[b]]le](
|
||||
[model_id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[model_name] [varchar](256) NOT NULL,
|
||||
[model_framework] [varchar](256) NULL,
|
||||
[model_framework_version] [varchar](256) NULL,
|
||||
[model] [varbinary](max) NOT NULL,
|
||||
[model_version] [varchar](256) NULL,
|
||||
[model_creation_time] [datetime2] NULL,
|
||||
[model_deployment_time] [datetime2] NULL,
|
||||
[deployed_by] [int] NULL,
|
||||
[model_description] [varchar](256) NULL,
|
||||
[run_id] [varchar](256) NULL,
|
||||
CONSTRAINT [ta[[b]]le_models_pk] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[model_id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
|
||||
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [ta[[b]]le_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
|
||||
END
|
||||
`;
|
||||
const actual = queries.getConfigureTableQuery(testContext.importTable);
|
||||
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
|
||||
});
|
||||
|
||||
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
testContext.importTable.databaseName = 'd[]b';
|
||||
testContext.importTable.tableName = 'ta[b]le';
|
||||
testContext.importTable.schema = 'dbo';
|
||||
const expected = `
|
||||
SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
|
||||
ORDER BY model_id
|
||||
`;
|
||||
const actual = queries.getDeployedModelsQuery(testContext.importTable);
|
||||
should.deepEqual(expected, actual);
|
||||
});
|
||||
|
||||
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const model: ImportedModel =
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
|
||||
const expected = `INSERT INTO [dbo].[tb]
|
||||
(model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id)
|
||||
VALUES (
|
||||
'name1',
|
||||
,
|
||||
'1.1',
|
||||
'desc1',
|
||||
'2018-01-01',
|
||||
'',
|
||||
'',
|
||||
'')`;
|
||||
const actual = queries.getInsertModelQuery(model, testContext.importTable);
|
||||
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
|
||||
});
|
||||
|
||||
it('getModelContentQuery should escape db name', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const model: ImportedModel =
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
|
||||
model.table = {
|
||||
databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo'
|
||||
};
|
||||
const expected = `
|
||||
SELECT model
|
||||
FROM [d[[]]b].[dbo].[ta[[b]]le]
|
||||
WHERE model_id = 1;
|
||||
`;
|
||||
const actual = queries.getModelContentQuery(model);
|
||||
should.deepEqual(actual, expected, `actual: ${actual} \n expected: ${expected}`);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,121 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { Config } from '../../configurations/config';
|
||||
|
||||
import * as utils from '../utils';
|
||||
import { ProcessService } from '../../common/processService';
|
||||
import { PackageManager } from '../../packageManagement/packageManager';
|
||||
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
processService: TypeMoq.IMock<ProcessService>;
|
||||
packageManager: TypeMoq.IMock<PackageManager>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
const context = utils.createContext();
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
outputChannel: context.outputChannel,
|
||||
op: context.op,
|
||||
processService: TypeMoq.Mock.ofType(ProcessService),
|
||||
packageManager: TypeMoq.Mock.ofType(PackageManager)
|
||||
};
|
||||
}
|
||||
|
||||
describe('ModelPythonClient', () => {
|
||||
it('deployModel should deploy the model successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
const modelPath = 'C:\\test';
|
||||
let service = new ModelPythonClient(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.packageManager.object);
|
||||
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
||||
|
||||
await service.deployModel(connection, modelPath);
|
||||
});
|
||||
|
||||
it('loadModelParameters should load model parameters successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const modelPath = 'C:\\test';
|
||||
const expected = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
},
|
||||
]
|
||||
};
|
||||
const parametersJson = `
|
||||
{
|
||||
"inputs": [
|
||||
{
|
||||
"name": "p1",
|
||||
"type": "int"
|
||||
},
|
||||
{
|
||||
"name": "p2",
|
||||
"type": "varchar"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "o1",
|
||||
"type": "int"
|
||||
}
|
||||
]
|
||||
}
|
||||
`;
|
||||
let service = new ModelPythonClient(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.packageManager.object);
|
||||
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
|
||||
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
|
||||
const actual = await service.loadModelParameters(modelPath);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,73 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { PackageManagementService } from '../../packageManagement/packageManagementService';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
|
||||
};
|
||||
}
|
||||
|
||||
describe('Package Management Service', () => {
|
||||
it('openDocuments should open document in browser successfully', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
context.apiWrapper.setup(x => x.openExternal(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
|
||||
should.equal(await serverConfigManager.openDocuments(), true);
|
||||
});
|
||||
|
||||
it('isMachineLearningServiceEnabled should return true if external script is enabled', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
context.queryRunner.setup(x => x.isMachineLearningServiceEnabled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
should.equal(await serverConfigManager.isMachineLearningServiceEnabled(connection), true);
|
||||
});
|
||||
|
||||
it('isRInstalled should return true if R is installed', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
context.queryRunner.setup(x => x.isRInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
should.equal(await serverConfigManager.isRInstalled(connection), true);
|
||||
});
|
||||
|
||||
it('isPythonInstalled should return true if Python is installed', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
context.queryRunner.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
should.equal(await serverConfigManager.isPythonInstalled(connection), true);
|
||||
});
|
||||
|
||||
it('enableExternalScriptConfig should show error message if did not updated successfully', async function (): Promise<void> {
|
||||
const context = createContext();
|
||||
context.queryRunner.setup(x => x.updateExternalScriptConfig(TypeMoq.It.isAny(), true)).returns(() => Promise.resolve());
|
||||
context.queryRunner.setup(x => x.isMachineLearningServiceEnabled(TypeMoq.It.isAny())).returns(() => Promise.resolve(false));
|
||||
context.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
||||
context.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
||||
context.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
await serverConfigManager.enableExternalScriptConfig(connection);
|
||||
|
||||
context.apiWrapper.verify(x => x.showErrorMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,273 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { PackageManager } from '../../packageManagement/packageManager';
|
||||
import { createContext, TestContext } from './utils';
|
||||
|
||||
describe('Package Manager', () => {
|
||||
it('Should initialize SQL package manager successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
should.doesNotThrow(() => createPackageManager(testContext));
|
||||
});
|
||||
|
||||
it('Manage Package command Should execute the command for valid connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.managePackages();
|
||||
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
|
||||
it('Manage Package command Should execute the command if r installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
|
||||
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.managePackages();
|
||||
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
|
||||
it('Manage Package command Should show an error for connection without python installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
|
||||
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny()));
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
|
||||
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(false);});
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
|
||||
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.managePackages();
|
||||
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
|
||||
it('Manage Package command Should show an error for no connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
|
||||
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
|
||||
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny()));
|
||||
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.managePackages();
|
||||
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
|
||||
it('installDependencies Should download sqlmlutils if does not exist', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"},
|
||||
{"name":"sqlmlutils","version":"1.1.1"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
|
||||
});
|
||||
|
||||
it('installDependencies Should not install packages if already installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesInstalled = false;
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"},
|
||||
{"name":"sqlmlutils","version":"1.1.1"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
|
||||
if (command.indexOf('pip install') > 0) {
|
||||
packagesInstalled = true;
|
||||
}
|
||||
return Promise.resolve(installedPackages);
|
||||
});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
should.equal(packagesInstalled, false);
|
||||
});
|
||||
|
||||
it('installDependencies Should install packages that are not already installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesInstalled = false;
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
|
||||
if (command.indexOf('pip install') > 0) {
|
||||
packagesInstalled = true;
|
||||
}
|
||||
return Promise.resolve(installedPackages);
|
||||
});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
should.equal(packagesInstalled, true);
|
||||
});
|
||||
|
||||
it('installDependencies Should not install packages if runtime is disabled in setting', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
testContext.config.setup(x => x.rEnabled).returns(() => false);
|
||||
testContext.config.setup(x => x.pythonEnabled).returns(() => false);
|
||||
let packagesInstalled = false;
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
|
||||
if (command.indexOf('pip install') > 0 || command.indexOf('install.packages') > 0) {
|
||||
packagesInstalled = true;
|
||||
}
|
||||
return Promise.resolve(installedPackages);
|
||||
});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
should.equal(packagesInstalled, false);
|
||||
});
|
||||
|
||||
it('installDependencies Should install packages that have older version installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesInstalled = false;
|
||||
let installedPackages = `[
|
||||
{"name":"sqlmlutils","version":"0.1.1"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
|
||||
if (command.indexOf('pip install') > 0) {
|
||||
packagesInstalled = true;
|
||||
}
|
||||
return Promise.resolve(installedPackages);
|
||||
});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
should.equal(packagesInstalled, true);
|
||||
});
|
||||
|
||||
it('installDependencies Should install packages if list packages fails', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesInstalled = false;
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
|
||||
label: 'Yes'
|
||||
}));
|
||||
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command,) => {
|
||||
if (command.indexOf('pip list') > 0) {
|
||||
return Promise.reject();
|
||||
} else if (command.indexOf('pip install') > 0) {
|
||||
packagesInstalled = true;
|
||||
return Promise.resolve('');
|
||||
} else {
|
||||
return Promise.resolve('');
|
||||
}
|
||||
});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await packageManager.installDependencies();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
|
||||
should.equal(packagesInstalled, true);
|
||||
});
|
||||
|
||||
it('installDependencies Should fail if download packages fails', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesInstalled = false;
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject());
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
|
||||
if (command.indexOf('pip list') > 0) {
|
||||
return Promise.resolve(installedPackages);
|
||||
} else if (command.indexOf('pip install') > 0) {
|
||||
return Promise.reject();
|
||||
} else {
|
||||
return Promise.resolve('');
|
||||
}
|
||||
});
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await should(packageManager.installDependencies()).rejected();
|
||||
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Failed);
|
||||
should.equal(packagesInstalled, false);
|
||||
});
|
||||
|
||||
function createPackageManager(testContext: TestContext): PackageManager {
|
||||
testContext.config.setup(x => x.requiredSqlPythonPackages).returns( () => [
|
||||
{ name: 'pymssql', version: '2.1.4' },
|
||||
{ name: 'sqlmlutils', version: '' }
|
||||
]);
|
||||
testContext.config.setup(x => x.requiredSqlRPackages).returns( () => [
|
||||
{ name: 'RODBCext', repository: 'https://cran.microsoft.com' },
|
||||
{ name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'}
|
||||
]);
|
||||
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
|
||||
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
|
||||
testContext.config.setup(x => x.rEnabled).returns(() => true);
|
||||
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
|
||||
let packageManager = new PackageManager(
|
||||
testContext.outputChannel,
|
||||
'',
|
||||
testContext.apiWrapper.object,
|
||||
testContext.serverConfigManager.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object);
|
||||
packageManager.init();
|
||||
packageManager.dependenciesInstalled = true;
|
||||
return packageManager;
|
||||
}
|
||||
});
|
||||
@@ -0,0 +1,399 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { SqlPythonPackageManageProvider } from '../../packageManagement/sqlPythonPackageManageProvider';
|
||||
import { createContext, TestContext } from './utils';
|
||||
import * as nbExtensionApis from '../../typings/notebookServices';
|
||||
|
||||
describe('SQL Python Package Manager', () => {
|
||||
it('Should create SQL package manager successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
should.doesNotThrow(() => createProvider(testContext));
|
||||
});
|
||||
|
||||
it('Should return provider Id and target correctly', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let provider = createProvider(testContext);
|
||||
should.deepEqual(SqlPythonPackageManageProvider.ProviderId, provider.providerId);
|
||||
should.deepEqual({ location: 'SQL', packageType: 'Python' }, provider.packageTarget);
|
||||
});
|
||||
|
||||
it('listPackages Should return packages sorted by name', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
},
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.listPackages(connection.databaseName);
|
||||
let expected = [
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
}
|
||||
];
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('listPackages Should return packages sorted by name and version', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.2'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.listPackages(connection.databaseName);
|
||||
let expected = [
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.2'
|
||||
}
|
||||
];
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('listPackages Should return empty packages if undefined packages returned', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let packages: nbExtensionApis.IPackageDetails[];
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.listPackages(connection.databaseName);
|
||||
let expected: nbExtensionApis.IPackageDetails[] = [];
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('listPackages Should return empty packages if empty packages returned', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve([]));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.listPackages(connection.databaseName);
|
||||
let expected: nbExtensionApis.IPackageDetails[] = [];
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('installPackages Should install given packages successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
connection.serverName = 'serverName';
|
||||
connection.databaseName = 'databaseName';
|
||||
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
|
||||
|
||||
if (path && scripts.find(x => x.indexOf('install') > 0) &&
|
||||
scripts.find(x => x.indexOf('port=1433') > 0) &&
|
||||
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('package="a-name"') > 0) &&
|
||||
scripts.find(x => x.indexOf('version="1.1.2"') > 0) &&
|
||||
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
|
||||
packagesUpdated = true;
|
||||
}
|
||||
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.installPackages(packages, false, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, true);
|
||||
});
|
||||
|
||||
it('uninstallPackages Should uninstall given packages successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
connection.serverName = 'serverName';
|
||||
connection.databaseName = 'databaseName';
|
||||
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
|
||||
|
||||
if (path && scripts.find(x => x.indexOf('uninstall') > 0) &&
|
||||
scripts.find(x => x.indexOf('port=1433') > 0) &&
|
||||
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('package_name="a-name"') > 0) &&
|
||||
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
|
||||
packagesUpdated = true;
|
||||
}
|
||||
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.uninstallPackages(packages, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, true);
|
||||
});
|
||||
|
||||
it('installPackages Should include port name in the script', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
connection.serverName = 'serverName,3433';
|
||||
connection.databaseName = 'databaseName';
|
||||
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
|
||||
|
||||
if (path && scripts.find(x => x.indexOf('install') > 0) &&
|
||||
scripts.find(x => x.indexOf('port=3433') > 0) &&
|
||||
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('package="a-name"') > 0) &&
|
||||
scripts.find(x => x.indexOf('version="1.1.2"') > 0) &&
|
||||
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
|
||||
packagesUpdated = true;
|
||||
}
|
||||
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.installPackages(packages, false, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, true);
|
||||
});
|
||||
|
||||
it('installPackages Should not install any packages give empty list', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
packagesUpdated = true;
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.installPackages(packages, false, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, false);
|
||||
});
|
||||
|
||||
it('uninstallPackages Should not uninstall any packages give empty list', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
packagesUpdated = true;
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.uninstallPackages(packages, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, false);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return false for no connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, false);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return false if connection does not have python installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(false));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, false);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return true if connection has python installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, true);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return false if python is disabled in setting', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
testContext.config.setup(x => x.pythonEnabled).returns(() => false);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, false);
|
||||
});
|
||||
|
||||
it('getPackageOverview Should return package info using python packages provider', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagePreview = {
|
||||
name: 'package name',
|
||||
versions: ['0.0.2', '0.0.1'],
|
||||
summary: 'package summary'
|
||||
};
|
||||
testContext.httpClient.setup(x => x.fetch(TypeMoq.It.isAny())).returns(() => {
|
||||
return Promise.resolve(`{"info":{"summary":"package summary"}, "releases":{"0.0.1":[{"comment_text":""}], "0.0.2":[{"comment_text":""}]}}`);
|
||||
});
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getPackageOverview('package name');
|
||||
|
||||
should.deepEqual(actual, packagePreview);
|
||||
});
|
||||
|
||||
it('getLocations Should return empty array for no connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocations();
|
||||
|
||||
should.deepEqual(actual, []);
|
||||
});
|
||||
|
||||
it('getLocations Should return database names for valid connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
connection.serverName = 'serverName';
|
||||
connection.databaseName = 'databaseName';
|
||||
const databaseNames = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const expected = [
|
||||
{
|
||||
displayName: 'db1',
|
||||
name: 'db1'
|
||||
},
|
||||
{
|
||||
displayName: 'db2',
|
||||
name: 'db2'
|
||||
}
|
||||
];
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.listDatabases(connection.connectionId)).returns(() => { return Promise.resolve(databaseNames); });
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocations();
|
||||
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
function createProvider(testContext: TestContext): SqlPythonPackageManageProvider {
|
||||
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
|
||||
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
|
||||
return new SqlPythonPackageManageProvider(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.serverConfigManager.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object);
|
||||
}
|
||||
});
|
||||
@@ -0,0 +1,325 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { SqlRPackageManageProvider } from '../../packageManagement/sqlRPackageManageProvider';
|
||||
import { createContext, TestContext } from './utils';
|
||||
import * as nbExtensionApis from '../../typings/notebookServices';
|
||||
|
||||
describe('SQL R Package Manager', () => {
|
||||
it('Should create SQL package manager successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
should.doesNotThrow(() => createProvider(testContext));
|
||||
});
|
||||
|
||||
it('Should return provider Id and target correctly', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let provider = createProvider(testContext);
|
||||
should.deepEqual(SqlRPackageManageProvider.ProviderId, provider.providerId);
|
||||
should.deepEqual({ location: 'SQL', packageType: 'R' }, provider.packageTarget);
|
||||
});
|
||||
|
||||
it('listPackages Should return packages sorted by name', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
},
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.listPackages(connection.databaseName);
|
||||
let expected = [
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
}
|
||||
];
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('listPackages Should return empty packages if undefined packages returned', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let packages: nbExtensionApis.IPackageDetails[];
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.listPackages(connection.databaseName);
|
||||
let expected: nbExtensionApis.IPackageDetails[] = [];
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('listPackages Should return empty packages if empty packages returned', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve([]));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.listPackages(connection.databaseName);
|
||||
let expected: nbExtensionApis.IPackageDetails[] = [];
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('installPackages Should install given packages successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
connection.serverName = 'serverName';
|
||||
connection.databaseName = 'databaseName';
|
||||
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
|
||||
|
||||
if (path && scripts.find(x => x.indexOf('install') > 0) &&
|
||||
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('"a-name"') > 0) &&
|
||||
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
|
||||
packagesUpdated = true;
|
||||
}
|
||||
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.installPackages(packages, false, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, true);
|
||||
});
|
||||
|
||||
it('uninstallPackages Should uninstall given packages successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
{
|
||||
'name': 'a-name',
|
||||
'version': '1.1.2'
|
||||
},
|
||||
{
|
||||
'name': 'b-name',
|
||||
'version': '1.1.1'
|
||||
}
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
connection.serverName = 'serverName';
|
||||
connection.databaseName = 'databaseName';
|
||||
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
|
||||
|
||||
if (path && scripts.find(x => x.indexOf('remove') > 0) &&
|
||||
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
|
||||
scripts.find(x => x.indexOf('"a-name"') > 0) &&
|
||||
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
|
||||
packagesUpdated = true;
|
||||
}
|
||||
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.uninstallPackages(packages, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, true);
|
||||
});
|
||||
|
||||
it('installPackages Should not install any packages give empty list', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
packagesUpdated = true;
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.installPackages(packages, false, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, false);
|
||||
});
|
||||
|
||||
it('uninstallPackages Should not uninstall any packages give empty list', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesUpdated = false;
|
||||
let packages: nbExtensionApis.IPackageDetails[] = [
|
||||
];
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
packagesUpdated = true;
|
||||
return Promise.resolve('');
|
||||
});
|
||||
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
await provider.uninstallPackages(packages, connection.databaseName);
|
||||
|
||||
should.deepEqual(packagesUpdated, false);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return false for no connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, false);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return false if connection does not have r installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.isRInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(false));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, false);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return true if connection has r installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.serverConfigManager.setup(x => x.isRInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, true);
|
||||
});
|
||||
|
||||
it('canUseProvider Should return false if r is disabled in setting', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
testContext.config.setup(x => x.rEnabled).returns(() => false);
|
||||
let actual = await provider.canUseProvider();
|
||||
|
||||
should.deepEqual(actual, false);
|
||||
});
|
||||
|
||||
it('getPackageOverview Should return package info successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagePreview = {
|
||||
'name': 'a-name',
|
||||
'versions': ['Latest'],
|
||||
'summary': ''
|
||||
};
|
||||
|
||||
testContext.httpClient.setup(x => x.fetch(TypeMoq.It.isAny())).returns(() => {
|
||||
return Promise.resolve(``);
|
||||
});
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getPackageOverview('a-name');
|
||||
|
||||
should.deepEqual(actual, packagePreview);
|
||||
});
|
||||
|
||||
it('getLocations Should return empty array for no connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection: azdata.connection.ConnectionProfile;
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocations();
|
||||
|
||||
should.deepEqual(actual, []);
|
||||
});
|
||||
|
||||
it('getLocations Should return database names for valid connection', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
connection.serverName = 'serverName';
|
||||
connection.databaseName = 'databaseName';
|
||||
const databaseNames = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const expected = [
|
||||
{
|
||||
displayName: 'db1',
|
||||
name: 'db1'
|
||||
},
|
||||
{
|
||||
displayName: 'db2',
|
||||
name: 'db2'
|
||||
}
|
||||
];
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.listDatabases(connection.connectionId)).returns(() => { return Promise.resolve(databaseNames); });
|
||||
|
||||
let provider = createProvider(testContext);
|
||||
let actual = await provider.getLocations();
|
||||
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
function createProvider(testContext: TestContext): SqlRPackageManageProvider {
|
||||
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
|
||||
testContext.config.setup(x => x.rEnabled).returns(() => true);
|
||||
testContext.config.setup(x => x.rPackagesRepository).returns(() => 'http://cran.r-project.org');
|
||||
return new SqlRPackageManageProvider(
|
||||
testContext.outputChannel,
|
||||
testContext.apiWrapper.object,
|
||||
testContext.serverConfigManager.object,
|
||||
testContext.processService.object,
|
||||
testContext.config.object,
|
||||
testContext.httpClient.object);
|
||||
}
|
||||
});
|
||||
@@ -0,0 +1,45 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { ProcessService } from '../../common/processService';
|
||||
import { Config } from '../../configurations/config';
|
||||
import { HttpClient } from '../../common/httpClient';
|
||||
import * as utils from '../utils';
|
||||
import { PackageManagementService } from '../../packageManagement/packageManagementService';
|
||||
|
||||
export interface TestContext {
|
||||
|
||||
outputChannel: vscode.OutputChannel;
|
||||
processService: TypeMoq.IMock<ProcessService>;
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
config: TypeMoq.IMock<Config>;
|
||||
op: azdata.BackgroundOperation;
|
||||
getOpStatus: () => azdata.TaskStatus;
|
||||
httpClient: TypeMoq.IMock<HttpClient>;
|
||||
serverConfigManager: TypeMoq.IMock<PackageManagementService>;
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
const context = utils.createContext();
|
||||
|
||||
return {
|
||||
|
||||
outputChannel: context.outputChannel,
|
||||
processService: TypeMoq.Mock.ofType(ProcessService),
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
|
||||
config: TypeMoq.Mock.ofType(Config),
|
||||
httpClient: TypeMoq.Mock.ofType(HttpClient),
|
||||
op: context.op,
|
||||
getOpStatus: context.getOpStatus,
|
||||
serverConfigManager: TypeMoq.Mock.ofType(PackageManagementService)
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
import { QueryRunner } from '../../common/queryRunner';
|
||||
import { ImportedModel } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import * as path from 'path';
|
||||
import * as os from 'os';
|
||||
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
|
||||
import * as fs from 'fs';
|
||||
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
importTable: DatabaseTable;
|
||||
queryRunner: TypeMoq.IMock<QueryRunner>;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
importTable: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
},
|
||||
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
|
||||
};
|
||||
}
|
||||
|
||||
describe('PredictService', () => {
|
||||
|
||||
it('getDatabaseList should return databases successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.listDatabases(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(expected); });
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object);
|
||||
const actual = await service.getDatabaseList();
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getTableList should return tables successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [[
|
||||
{
|
||||
displayValue: 'tb1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'dbo',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
], [
|
||||
{
|
||||
displayValue: 'tb2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'dbo',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]]
|
||||
};
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object);
|
||||
const actual = await service.getTableList('db1');
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getTableColumnsList should return table columns successfully', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const expected: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const table: DatabaseTable =
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
};
|
||||
|
||||
const result = {
|
||||
rowCount: 1,
|
||||
columnInfo: [],
|
||||
rows: [[
|
||||
{
|
||||
displayValue: 'c1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'int',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
], [
|
||||
{
|
||||
displayValue: 'c2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
},
|
||||
{
|
||||
displayValue: 'varchar',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}
|
||||
]]
|
||||
};
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
|
||||
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object);
|
||||
const actual = await service.getTableColumnsList(table);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('generatePredictScript should generate the script successfully using model', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const predictParams: PredictParameters = {
|
||||
inputColumns: [
|
||||
{
|
||||
paramName: 'p1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
dataType: 'varchar',
|
||||
columnName: ''
|
||||
}
|
||||
],
|
||||
outputColumns: [
|
||||
{
|
||||
paramName: 'o1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
],
|
||||
databaseName: '',
|
||||
tableName: '',
|
||||
schema: ''
|
||||
};
|
||||
const model: ImportedModel =
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1',
|
||||
table: testContext.importTable
|
||||
};
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
fileName: '',
|
||||
isUntitled: true,
|
||||
languageId: 'sql',
|
||||
version: 1,
|
||||
isDirty: true,
|
||||
isClosed: false,
|
||||
save: undefined!,
|
||||
eol: undefined!,
|
||||
lineCount: 1,
|
||||
lineAt: undefined!,
|
||||
offsetAt: undefined!,
|
||||
positionAt: undefined!,
|
||||
getText: undefined!,
|
||||
getWordRangeAtPosition: undefined!,
|
||||
validateRange: undefined!,
|
||||
validatePosition: undefined!
|
||||
};
|
||||
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
|
||||
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
|
||||
|
||||
const actual = await service.generatePredictScript(predictParams, model, undefined);
|
||||
should.notEqual(actual, undefined);
|
||||
should.equal(actual.indexOf('FROM PREDICT(MODEL = @model') > 0, true);
|
||||
});
|
||||
|
||||
it('generatePredictScript should generate the script successfully using file', async function (): Promise<void> {
|
||||
const testContext = createContext();
|
||||
const connection = new azdata.connection.ConnectionProfile();
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
const predictParams: PredictParameters = {
|
||||
inputColumns: [
|
||||
{
|
||||
paramName: 'p1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
{
|
||||
paramName: 'p2',
|
||||
dataType: 'varchar',
|
||||
columnName: ''
|
||||
}
|
||||
],
|
||||
outputColumns: [
|
||||
{
|
||||
paramName: 'o1',
|
||||
dataType: 'int',
|
||||
columnName: ''
|
||||
},
|
||||
],
|
||||
databaseName: '',
|
||||
tableName: '',
|
||||
schema: ''
|
||||
};
|
||||
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
|
||||
await fs.promises.writeFile(tempFilePath, 'test');
|
||||
|
||||
let service = new PredictService(
|
||||
testContext.apiWrapper.object,
|
||||
testContext.queryRunner.object);
|
||||
|
||||
const document: vscode.TextDocument = {
|
||||
uri: vscode.Uri.parse('file:///usr/home'),
|
||||
fileName: '',
|
||||
isUntitled: true,
|
||||
languageId: 'sql',
|
||||
version: 1,
|
||||
isDirty: true,
|
||||
isClosed: false,
|
||||
save: undefined!,
|
||||
eol: undefined!,
|
||||
lineCount: 1,
|
||||
lineAt: undefined!,
|
||||
offsetAt: undefined!,
|
||||
positionAt: undefined!,
|
||||
getText: undefined!,
|
||||
getWordRangeAtPosition: undefined!,
|
||||
validateRange: undefined!,
|
||||
validatePosition: undefined!
|
||||
};
|
||||
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
|
||||
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
|
||||
|
||||
const actual = await service.generatePredictScript(predictParams, undefined, tempFilePath);
|
||||
should.notEqual(actual, undefined);
|
||||
should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
|
||||
});
|
||||
});
|
||||
303
extensions/machine-learning/src/test/queryRunner.test.ts
Normal file
303
extensions/machine-learning/src/test/queryRunner.test.ts
Normal file
@@ -0,0 +1,303 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import * as should from 'should';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { IPackageDetails } from '../typings/notebookServices';
|
||||
|
||||
interface TestContext {
|
||||
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
queryProvider: azdata.QueryProvider;
|
||||
}
|
||||
|
||||
function createContext(): TestContext {
|
||||
return {
|
||||
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
|
||||
queryProvider: {
|
||||
providerId: '',
|
||||
cancelQuery: () => {return Promise.reject();},
|
||||
runQuery: () => {return Promise.reject();},
|
||||
runQueryStatement: () => {return Promise.reject();},
|
||||
runQueryString: () => {return Promise.reject();},
|
||||
runQueryAndReturn: () => { return Promise.reject(); },
|
||||
parseSyntax: () => {return Promise.reject();},
|
||||
getQueryRows: () => {return Promise.reject();},
|
||||
disposeQuery: () => {return Promise.reject();},
|
||||
saveResults: () => {return Promise.reject();},
|
||||
setQueryExecutionOptions: () => {return Promise.reject();},
|
||||
registerOnQueryComplete: () => {return Promise.reject();},
|
||||
registerOnBatchStart: () => {return Promise.reject();},
|
||||
registerOnBatchComplete: () => {return Promise.reject();},
|
||||
registerOnResultSetAvailable: () => {return Promise.reject();},
|
||||
registerOnResultSetUpdated: () => {return Promise.reject();},
|
||||
registerOnMessage: () => {return Promise.reject();},
|
||||
commitEdit: () => {return Promise.reject();},
|
||||
createRow: () => {return Promise.reject();},
|
||||
deleteRow: () => {return Promise.reject();},
|
||||
disposeEdit: () => {return Promise.reject();},
|
||||
initializeEdit: () => {return Promise.reject();},
|
||||
revertCell: () => {return Promise.reject();},
|
||||
revertRow: () => {return Promise.reject();},
|
||||
updateCell: () => {return Promise.reject();},
|
||||
getEditRows: () => {return Promise.reject();},
|
||||
registerOnEditSessionReady: () => {return Promise.reject();},
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
describe('Query Runner', () => {
|
||||
it('getPythonPackages Should return empty list if not provider found', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
let queryProvider: azdata.QueryProvider;
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => queryProvider);
|
||||
|
||||
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
|
||||
should.deepEqual(actual, []);
|
||||
});
|
||||
|
||||
it('getPythonPackages Should return empty list if not provider throws', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.reject(); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
|
||||
should.deepEqual(actual, []);
|
||||
});
|
||||
|
||||
it('getPythonPackages Should return list if provider runs the query successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [
|
||||
[{
|
||||
displayValue: 'p1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}, {
|
||||
displayValue: '1.1.1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}],
|
||||
[{
|
||||
displayValue: 'p2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}, {
|
||||
displayValue: '1.1.2',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}]
|
||||
];
|
||||
let expected = [
|
||||
{
|
||||
'name': 'p1',
|
||||
'version': '1.1.1'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'version': '1.1.2'
|
||||
}
|
||||
];
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
|
||||
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('getPythonPackages Should return empty list if provider return no rows', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [
|
||||
];
|
||||
let expected: IPackageDetails[] = [];
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
|
||||
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('updateExternalScriptConfig Should update config successfully', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [
|
||||
];
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
await should(queryRunner.updateExternalScriptConfig(connection, true)).resolved();
|
||||
|
||||
});
|
||||
|
||||
it('isPythonInstalled Should return true is provider returns valid result', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [
|
||||
[{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}]
|
||||
];
|
||||
let expected = true;
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.isPythonInstalled(connection);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('isPythonInstalled Should return true is provider returns 0 as result', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [
|
||||
[{
|
||||
displayValue: '0',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}]
|
||||
];
|
||||
let expected = false;
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.isPythonInstalled(connection);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('isPythonInstalled Should return false is provider returns no result', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [];
|
||||
let expected = false;
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.isPythonInstalled(connection);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('isMachineLearningServiceEnabled Should return true is provider returns valid result', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [
|
||||
[{
|
||||
displayValue: '1',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}]
|
||||
];
|
||||
let expected = true;
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('isMachineLearningServiceEnabled Should return true is provider returns 0 as result', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [
|
||||
[{
|
||||
displayValue: '0',
|
||||
isNull: false,
|
||||
invariantCultureDisplayValue: ''
|
||||
}]
|
||||
];
|
||||
let expected = false;
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
it('isMachineLearningServiceEnabled Should return false is provider returns no result', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let rows: azdata.DbCellValue[][] = [];
|
||||
let expected = false;
|
||||
|
||||
let result : azdata.SimpleExecuteResult = {
|
||||
rowCount: 2,
|
||||
columnInfo: [],
|
||||
rows: rows,
|
||||
};
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
|
||||
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
|
||||
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
|
||||
|
||||
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
|
||||
should.deepEqual(actual, expected);
|
||||
});
|
||||
|
||||
});
|
||||
38
extensions/machine-learning/src/test/utils.ts
Normal file
38
extensions/machine-learning/src/test/utils.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
export interface TestContext {
|
||||
|
||||
outputChannel: vscode.OutputChannel;
|
||||
op: azdata.BackgroundOperation;
|
||||
getOpStatus: () => azdata.TaskStatus;
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
let opStatus: azdata.TaskStatus;
|
||||
|
||||
return {
|
||||
outputChannel: {
|
||||
name: '',
|
||||
append: () => { },
|
||||
appendLine: () => { },
|
||||
clear: () => { },
|
||||
show: () => { },
|
||||
hide: () => { },
|
||||
dispose: () => { }
|
||||
},
|
||||
op: {
|
||||
updateStatus: (status: azdata.TaskStatus) => {
|
||||
opStatus = status;
|
||||
},
|
||||
id: '',
|
||||
onCanceled: new vscode.EventEmitter<void>().event,
|
||||
},
|
||||
getOpStatus: () => { return opStatus; }
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { createViewContext } from './utils';
|
||||
import { DashboardWidget } from '../../views/widgets/dashboardWidget';
|
||||
|
||||
interface TestContext {
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
view: azdata.ModelView;
|
||||
onClick: vscode.EventEmitter<any>;
|
||||
}
|
||||
|
||||
|
||||
function createContext(): TestContext {
|
||||
|
||||
let viewTestContext = createViewContext();
|
||||
|
||||
return {
|
||||
apiWrapper: viewTestContext.apiWrapper,
|
||||
view: viewTestContext.view,
|
||||
onClick: viewTestContext.onClick
|
||||
};
|
||||
}
|
||||
|
||||
describe('Dashboard widget', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '');
|
||||
dashboard.register();
|
||||
testContext.onClick.fire(undefined);
|
||||
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce());
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,120 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext, ParentDialog } from './utils';
|
||||
import { AddEditLanguageTab } from '../../../views/externalLanguages/addEditLanguageTab';
|
||||
import { LanguageUpdateModel } from '../../../views/externalLanguages/languageViewBase';
|
||||
|
||||
describe('Add Edit External Languages Tab', () => {
|
||||
it('Should create AddEditLanguageTab for new language successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
let languageUpdateModel: LanguageUpdateModel = {
|
||||
content: parent.createNewContent(),
|
||||
language: parent.createNewLanguage(),
|
||||
newLang: true
|
||||
};
|
||||
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
|
||||
should.notEqual(tab.languageView, undefined, 'Failed to create language view for add');
|
||||
});
|
||||
|
||||
it('Should create AddEditLanguageTab for edit successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
let languageUpdateModel: LanguageUpdateModel = {
|
||||
content: {
|
||||
extensionFileName: 'filename',
|
||||
isLocalFile: true,
|
||||
pathToExtension: 'path',
|
||||
},
|
||||
language: {
|
||||
name: 'name',
|
||||
contents: []
|
||||
},
|
||||
newLang: false
|
||||
};
|
||||
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
|
||||
should.notEqual(tab.languageView, undefined, 'Failed to create language view for edit');
|
||||
should.equal(tab.saveButton, undefined);
|
||||
});
|
||||
|
||||
it('Should reset AddEditLanguageTab successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
let languageUpdateModel: LanguageUpdateModel = {
|
||||
content: {
|
||||
extensionFileName: 'filename',
|
||||
isLocalFile: true,
|
||||
pathToExtension: 'path',
|
||||
},
|
||||
language: {
|
||||
name: 'name',
|
||||
contents: []
|
||||
},
|
||||
newLang: false
|
||||
};
|
||||
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
|
||||
if (tab.languageName) {
|
||||
tab.languageName.value = 'some value';
|
||||
}
|
||||
await tab.reset();
|
||||
should.equal(tab.languageName?.value, 'name');
|
||||
});
|
||||
|
||||
it('Should load content successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
let languageUpdateModel: LanguageUpdateModel = {
|
||||
content: {
|
||||
extensionFileName: 'filename',
|
||||
isLocalFile: true,
|
||||
pathToExtension: 'path',
|
||||
environmentVariables: 'env vars',
|
||||
parameters: 'params'
|
||||
},
|
||||
language: {
|
||||
name: 'name',
|
||||
contents: []
|
||||
},
|
||||
newLang: false
|
||||
};
|
||||
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
|
||||
let content = tab.languageView?.updatedContent;
|
||||
should.notEqual(content, undefined);
|
||||
if (content) {
|
||||
should.equal(content.extensionFileName, languageUpdateModel.content.extensionFileName);
|
||||
should.equal(content.pathToExtension, languageUpdateModel.content.pathToExtension);
|
||||
should.equal(content.environmentVariables, languageUpdateModel.content.environmentVariables);
|
||||
should.equal(content.parameters, languageUpdateModel.content.parameters);
|
||||
}
|
||||
});
|
||||
|
||||
it('Should raise save event if save button clicked ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
let languageUpdateModel: LanguageUpdateModel = {
|
||||
content: parent.createNewContent(),
|
||||
language: parent.createNewLanguage(),
|
||||
newLang: true
|
||||
};
|
||||
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
|
||||
should.notEqual(tab.saveButton, undefined);
|
||||
let updateCalled = false;
|
||||
let promise = new Promise(resolve => {
|
||||
parent.onUpdate(() => {
|
||||
updateCalled = true;
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
testContext.onClick.fire(undefined);
|
||||
parent.onUpdatedLanguage(languageUpdateModel);
|
||||
await promise;
|
||||
should.equal(updateCalled, true);
|
||||
should.notEqual(tab.updatedData, undefined);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,104 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { createContext } from './utils';
|
||||
import { LanguageController } from '../../../views/externalLanguages/languageController';
|
||||
import * as mssql from '../../../../../mssql';
|
||||
|
||||
describe('External Languages Controller', () => {
|
||||
it('Should open dialog for manage languages successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
|
||||
let dialog = await controller.manageLanguages();
|
||||
testContext.apiWrapper.verify(x => x.openDialog(TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
should.notEqual(dialog, undefined);
|
||||
});
|
||||
|
||||
it('Should list languages successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let languages: mssql.ExternalLanguage[] = [{
|
||||
name: '',
|
||||
contents: [{
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
}]
|
||||
}];
|
||||
|
||||
testContext.dialogModel.setup( x=> x.getLanguageList()).returns(() => Promise.resolve(languages));
|
||||
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
|
||||
let dialog = await controller.manageLanguages();
|
||||
let actual = await dialog.listLanguages();
|
||||
should.deepEqual(actual, languages);
|
||||
});
|
||||
|
||||
it('Should update languages successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let language: mssql.ExternalLanguage = {
|
||||
name: '',
|
||||
contents: [{
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
}]
|
||||
};
|
||||
|
||||
testContext.dialogModel.setup( x=> x.updateLanguage(language)).returns(() => Promise.resolve());
|
||||
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
|
||||
let dialog = await controller.manageLanguages();
|
||||
await dialog.updateLanguage({
|
||||
language: language,
|
||||
content: language.contents[0],
|
||||
newLang: false
|
||||
});
|
||||
testContext.dialogModel.verify(x => x.updateLanguage(TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
|
||||
it('Should delete language successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let language: mssql.ExternalLanguage = {
|
||||
name: '',
|
||||
contents: [{
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
}]
|
||||
};
|
||||
|
||||
testContext.dialogModel.setup( x=> x.deleteLanguage(language.name)).returns(() => Promise.resolve());
|
||||
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
|
||||
let dialog = await controller.manageLanguages();
|
||||
await dialog.deleteLanguage({
|
||||
language: language,
|
||||
content: language.contents[0],
|
||||
newLang: false
|
||||
});
|
||||
testContext.dialogModel.verify(x => x.deleteLanguage(TypeMoq.It.isAny()), TypeMoq.Times.once());
|
||||
});
|
||||
|
||||
it('Should open edit dialog for edit language', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let language: mssql.ExternalLanguage = {
|
||||
name: '',
|
||||
contents: [{
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
}]
|
||||
};
|
||||
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
|
||||
let dialog = await controller.manageLanguages();
|
||||
dialog.onEditLanguage({
|
||||
language: language,
|
||||
content: language.contents[0],
|
||||
newLang: false
|
||||
});
|
||||
testContext.apiWrapper.verify(x => x.openDialog(TypeMoq.It.isAny()), TypeMoq.Times.exactly(2));
|
||||
should.notEqual(dialog, undefined);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,50 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext, ParentDialog } from './utils';
|
||||
import { LanguageEditDialog } from '../../../views/externalLanguages/languageEditDialog';
|
||||
import { LanguageUpdateModel } from '../../../views/externalLanguages/languageViewBase';
|
||||
|
||||
describe('Edit External Languages Dialog', () => {
|
||||
it('Should open dialog successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
let languageUpdateModel: LanguageUpdateModel = {
|
||||
content: parent.createNewContent(),
|
||||
language: parent.createNewLanguage(),
|
||||
newLang: true
|
||||
};
|
||||
let dialog = new LanguageEditDialog(testContext.apiWrapper.object, parent, languageUpdateModel);
|
||||
dialog.showDialog();
|
||||
should.notEqual(dialog.addNewLanguageTab, undefined);
|
||||
});
|
||||
|
||||
it('Should raise save event if save button clicked ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
let languageUpdateModel: LanguageUpdateModel = {
|
||||
content: parent.createNewContent(),
|
||||
language: parent.createNewLanguage(),
|
||||
newLang: true
|
||||
};
|
||||
let dialog = new LanguageEditDialog(testContext.apiWrapper.object, parent, languageUpdateModel);
|
||||
dialog.showDialog();
|
||||
|
||||
let updateCalled = false;
|
||||
let promise = new Promise(resolve => {
|
||||
parent.onUpdate(() => {
|
||||
updateCalled = true;
|
||||
parent.onUpdatedLanguage(languageUpdateModel);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
||||
dialog.onSave();
|
||||
await promise;
|
||||
should.equal(updateCalled, true);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,19 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { LanguagesDialog } from '../../../views/externalLanguages/languagesDialog';
|
||||
|
||||
describe('External Languages Dialog', () => {
|
||||
it('Should open dialog successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let dialog = new LanguagesDialog(testContext.apiWrapper.object, '');
|
||||
dialog.showDialog();
|
||||
should.notEqual(dialog.addNewLanguageTab, undefined);
|
||||
should.notEqual(dialog.currentLanguagesTab, undefined);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,61 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import * as mssql from '../../../../../mssql';
|
||||
import { LanguageService } from '../../../externalLanguage/languageService';
|
||||
|
||||
describe('External Languages Dialog Model', () => {
|
||||
it('Should list languages successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let languages: mssql.ExternalLanguage[] = [{
|
||||
name: '',
|
||||
contents: [{
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
}]
|
||||
}];
|
||||
testContext.languageExtensionService.listLanguages = () => {return Promise.resolve(languages);};
|
||||
let model = new LanguageService(testContext.apiWrapper.object, testContext.languageExtensionService);
|
||||
await model.load();
|
||||
let actual = await model.getLanguageList();
|
||||
should.deepEqual(actual, languages);
|
||||
});
|
||||
|
||||
it('Should update language successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let language: mssql.ExternalLanguage = {
|
||||
name: '',
|
||||
contents: [{
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
}]
|
||||
};
|
||||
|
||||
let model = new LanguageService(testContext.apiWrapper.object, testContext.languageExtensionService);
|
||||
await model.load();
|
||||
await should(model.updateLanguage(language)).resolved();
|
||||
});
|
||||
|
||||
it('Should delete language successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let language: mssql.ExternalLanguage = {
|
||||
name: '',
|
||||
contents: [{
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
}]
|
||||
};
|
||||
|
||||
let model = new LanguageService(testContext.apiWrapper.object, testContext.languageExtensionService);
|
||||
await model.load();
|
||||
await should(model.deleteLanguage(language.name)).resolved();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,53 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { LanguageViewBase } from '../../../views/externalLanguages/languageViewBase';
|
||||
import * as mssql from '../../../../../mssql';
|
||||
import { LanguageService } from '../../../externalLanguage/languageService';
|
||||
import { createViewContext } from '../utils';
|
||||
|
||||
export interface TestContext {
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
view: azdata.ModelView;
|
||||
languageExtensionService: mssql.ILanguageExtensionService;
|
||||
onClick: vscode.EventEmitter<any>;
|
||||
dialogModel: TypeMoq.IMock<LanguageService>;
|
||||
}
|
||||
|
||||
export class ParentDialog extends LanguageViewBase {
|
||||
public reset(): Promise<void> {
|
||||
return Promise.resolve();
|
||||
}
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper) {
|
||||
super(apiWrapper, '');
|
||||
}
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
|
||||
let viewTestContext = createViewContext();
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
viewTestContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
viewTestContext.apiWrapper.setup(x => x.getUriForConnection(TypeMoq.It.isAny())).returns(() => { return Promise.resolve('connectionUrl'); });
|
||||
|
||||
let languageExtensionService: mssql.ILanguageExtensionService = {
|
||||
listLanguages: () => { return Promise.resolve([]); },
|
||||
deleteLanguage: () => { return Promise.resolve(); },
|
||||
updateLanguage: () => { return Promise.resolve(); }
|
||||
};
|
||||
|
||||
return {
|
||||
apiWrapper: viewTestContext.apiWrapper,
|
||||
view: viewTestContext.view,
|
||||
languageExtensionService: languageExtensionService,
|
||||
onClick: viewTestContext.onClick,
|
||||
dialogModel: TypeMoq.Mock.ofType(LanguageService)
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { ModelManagementController } from '../../../views/models/modelManagementController';
|
||||
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||
import { DeleteModelEventName, UpdateModelEventName } from '../../../views/models/modelViewBase';
|
||||
import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog';
|
||||
|
||||
const accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
accountId: '1',
|
||||
providerId: ''
|
||||
},
|
||||
displayInfo: {
|
||||
displayName: 'account',
|
||||
userId: '',
|
||||
accountType: '',
|
||||
contextualDisplayName: ''
|
||||
},
|
||||
isStale: false,
|
||||
properties: []
|
||||
}
|
||||
];
|
||||
const subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||
{
|
||||
name: 'subscription',
|
||||
id: '2'
|
||||
}
|
||||
];
|
||||
const groups: azureResource.AzureResourceResourceGroup[] = [
|
||||
{
|
||||
name: 'group',
|
||||
id: '3'
|
||||
}
|
||||
];
|
||||
const workspaces: Workspace[] = [
|
||||
{
|
||||
name: 'workspace',
|
||||
id: '4'
|
||||
}
|
||||
];
|
||||
const models: WorkspaceModel[] = [
|
||||
{
|
||||
id: '5',
|
||||
name: 'model'
|
||||
}
|
||||
];
|
||||
const localModels: ImportedModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'model',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
|
||||
const dbNames: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const tableNames: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
const columnNames: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const modelParameters: ModelParameters = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
}
|
||||
]
|
||||
};
|
||||
describe('Model Controller', () => {
|
||||
|
||||
it('Should open deploy model wizard successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
|
||||
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
|
||||
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
|
||||
databaseName: 'db',
|
||||
tableName: 'table',
|
||||
schema: 'dbo'
|
||||
}));
|
||||
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
|
||||
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames));
|
||||
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames));
|
||||
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
|
||||
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
|
||||
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
|
||||
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
|
||||
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
|
||||
|
||||
const view = await controller.registerModel(undefined);
|
||||
should.notEqual(view, undefined);
|
||||
});
|
||||
|
||||
it('Should open predict wizard successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
|
||||
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
|
||||
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
|
||||
databaseName: 'db',
|
||||
tableName: 'table',
|
||||
schema: 'dbo'
|
||||
}));
|
||||
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
|
||||
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve([
|
||||
'db', 'db1'
|
||||
]));
|
||||
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve([
|
||||
{ tableName: 'tb', databaseName: 'db', schema: 'dbo' }
|
||||
]));
|
||||
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
|
||||
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
|
||||
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
|
||||
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
|
||||
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
|
||||
testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames));
|
||||
testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters));
|
||||
testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
|
||||
testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
|
||||
|
||||
const view = await controller.predictModel();
|
||||
should.notEqual(view, undefined);
|
||||
});
|
||||
|
||||
it('Should open edit model dialog successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
testContext.deployModelService.setup(x => x.deleteModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
|
||||
|
||||
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
|
||||
const model: ImportedModel =
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
};
|
||||
const view = <EditModelDialog>await controller.editModel(model);
|
||||
should.notEqual(view?.editModelPage, undefined);
|
||||
if (view.editModelPage) {
|
||||
view.editModelPage.sendRequest(UpdateModelEventName, model);
|
||||
view.editModelPage.sendRequest(DeleteModelEventName, model);
|
||||
}
|
||||
testContext.deployModelService.verify(x => x.updateModel(model), TypeMoq.Times.atLeastOnce());
|
||||
testContext.deployModelService.verify(x => x.deleteModel(model), TypeMoq.Times.atLeastOnce());
|
||||
|
||||
should.notEqual(view, undefined);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,102 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext, ParentDialog } from './utils';
|
||||
import { AzureModelsComponent } from '../../../views/models/azureModelsComponent';
|
||||
import { ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName } from '../../../views/models/modelViewBase';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
|
||||
describe('Azure Models Component', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
|
||||
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent);
|
||||
view.registerComponent(testContext.view.modelBuilder);
|
||||
should.notEqual(view.component, undefined);
|
||||
});
|
||||
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let parent = new ParentDialog(testContext.apiWrapper.object);
|
||||
|
||||
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent, false);
|
||||
view.registerComponent(testContext.view.modelBuilder);
|
||||
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
accountId: '1',
|
||||
providerId: ''
|
||||
},
|
||||
displayInfo: {
|
||||
displayName: 'account',
|
||||
userId: '',
|
||||
accountType: '',
|
||||
contextualDisplayName: ''
|
||||
},
|
||||
isStale: false,
|
||||
properties: []
|
||||
}
|
||||
];
|
||||
let subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||
{
|
||||
name: 'subscription',
|
||||
id: '2'
|
||||
}
|
||||
];
|
||||
let groups: azureResource.AzureResourceResourceGroup[] = [
|
||||
{
|
||||
name: 'group',
|
||||
id: '3'
|
||||
}
|
||||
];
|
||||
let workspaces: Workspace[] = [
|
||||
{
|
||||
name: 'workspace',
|
||||
id: '4'
|
||||
}
|
||||
];
|
||||
let models: WorkspaceModel[] = [
|
||||
{
|
||||
id: '5',
|
||||
name: 'model'
|
||||
}
|
||||
];
|
||||
parent.on(ListAccountsEventName, () => {
|
||||
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
||||
});
|
||||
parent.on(ListSubscriptionsEventName, () => {
|
||||
|
||||
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
|
||||
});
|
||||
parent.on(ListGroupsEventName, () => {
|
||||
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
|
||||
});
|
||||
parent.on(ListWorkspacesEventName, () => {
|
||||
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
|
||||
});
|
||||
parent.on(ListAzureModelsEventName, () => {
|
||||
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
await view.refresh();
|
||||
testContext.onClick.fire(true);
|
||||
should.notEqual(view.data, undefined);
|
||||
should.equal(view.data?.length, 1);
|
||||
if (view.data) {
|
||||
should.deepEqual(view.data[0].account, accounts[0]);
|
||||
should.deepEqual(view.data[0].subscription, subscriptions[0]);
|
||||
should.deepEqual(view.data[0].group, groups[0]);
|
||||
should.deepEqual(view.data[0].workspace, workspaces[0]);
|
||||
should.deepEqual(view.data[0].model, models[0]);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,33 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog';
|
||||
|
||||
describe('Edit Model Dialog', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
const model: ImportedModel =
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'name1',
|
||||
description: 'desc1',
|
||||
created: '2018-01-01',
|
||||
version: '1.1',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
};
|
||||
let view = new EditModelDialog(testContext.apiWrapper.object, '', undefined, model);
|
||||
view.open();
|
||||
|
||||
should.notEqual(view.dialogView, undefined);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,203 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import {
|
||||
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
|
||||
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType
|
||||
}
|
||||
from '../../../views/models/modelViewBase';
|
||||
import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { PredictWizard } from '../../../views/models/prediction/predictWizard';
|
||||
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||
|
||||
describe('Predict Wizard', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||
await view.open();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
});
|
||||
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new PredictWizard(testContext.apiWrapper.object, '');
|
||||
view.importTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
};
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
accountId: '1',
|
||||
providerId: ''
|
||||
},
|
||||
displayInfo: {
|
||||
displayName: 'account',
|
||||
userId: '',
|
||||
accountType: '',
|
||||
contextualDisplayName: ''
|
||||
},
|
||||
isStale: false,
|
||||
properties: []
|
||||
}
|
||||
];
|
||||
let subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||
{
|
||||
name: 'subscription',
|
||||
id: '2'
|
||||
}
|
||||
];
|
||||
let groups: azureResource.AzureResourceResourceGroup[] = [
|
||||
{
|
||||
name: 'group',
|
||||
id: '3'
|
||||
}
|
||||
];
|
||||
let workspaces: Workspace[] = [
|
||||
{
|
||||
name: 'workspace',
|
||||
id: '4'
|
||||
}
|
||||
];
|
||||
let models: WorkspaceModel[] = [
|
||||
{
|
||||
id: '5',
|
||||
name: 'model'
|
||||
}
|
||||
];
|
||||
let localModels: ImportedModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'model',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
const dbNames: string[] = [
|
||||
'db1',
|
||||
'db2'
|
||||
];
|
||||
const tableNames: DatabaseTable[] = [
|
||||
{
|
||||
databaseName: 'db1',
|
||||
schema: 'dbo',
|
||||
tableName: 'tb1'
|
||||
},
|
||||
{
|
||||
databaseName: 'db1',
|
||||
tableName: 'tb2',
|
||||
schema: 'dbo'
|
||||
}
|
||||
];
|
||||
const columnNames: TableColumn[] = [
|
||||
{
|
||||
columnName: 'c1',
|
||||
dataType: 'int'
|
||||
},
|
||||
{
|
||||
columnName: 'c2',
|
||||
dataType: 'varchar'
|
||||
}
|
||||
];
|
||||
const modelParameters: ModelParameters = {
|
||||
inputs: [
|
||||
{
|
||||
'name': 'p1',
|
||||
'type': 'int'
|
||||
},
|
||||
{
|
||||
'name': 'p2',
|
||||
'type': 'varchar'
|
||||
}
|
||||
],
|
||||
outputs: [
|
||||
{
|
||||
'name': 'o1',
|
||||
'type': 'int'
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
view.on(ListModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
|
||||
});
|
||||
view.on(ListAccountsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
||||
});
|
||||
view.on(ListSubscriptionsEventName, () => {
|
||||
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
|
||||
});
|
||||
view.on(ListGroupsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
|
||||
});
|
||||
view.on(ListWorkspacesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
|
||||
});
|
||||
view.on(ListAzureModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames });
|
||||
});
|
||||
view.on(ListTableNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames });
|
||||
});
|
||||
view.on(ListColumnNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames });
|
||||
});
|
||||
view.on(LoadModelParametersEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters });
|
||||
});
|
||||
view.on(DownloadAzureModelEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' });
|
||||
});
|
||||
view.on(DownloadRegisteredModelEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
|
||||
});
|
||||
if (view.modelBrowsePage) {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
|
||||
}
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data, undefined);
|
||||
|
||||
if (view.modelBrowsePage) {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
|
||||
}
|
||||
await view.refresh();
|
||||
testContext.onClick.fire(undefined);
|
||||
|
||||
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined);
|
||||
if (view.modelBrowsePage?.registeredModelsComponent?.data) {
|
||||
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1);
|
||||
}
|
||||
|
||||
|
||||
should.notEqual(await view.getModelFileName(), undefined);
|
||||
await view.columnsSelectionPage?.onEnter();
|
||||
|
||||
should.notEqual(view.columnsSelectionPage?.data, undefined);
|
||||
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name);
|
||||
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,131 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
import { azureResource } from '../../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
import { WorkspaceModel } from '../../../modelManagement/interfaces';
|
||||
import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard';
|
||||
|
||||
describe('Register Model Wizard', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
|
||||
view.importTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'table',
|
||||
schema: 'dbo'
|
||||
};
|
||||
await view.open();
|
||||
should.notEqual(view.wizardView, undefined);
|
||||
should.notEqual(view.modelSourcePage, undefined);
|
||||
});
|
||||
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
|
||||
view.importTable = {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
};
|
||||
await view.open();
|
||||
let accounts: azdata.Account[] = [
|
||||
{
|
||||
key: {
|
||||
accountId: '1',
|
||||
providerId: ''
|
||||
},
|
||||
displayInfo: {
|
||||
displayName: 'account',
|
||||
userId: '',
|
||||
accountType: '',
|
||||
contextualDisplayName: ''
|
||||
},
|
||||
isStale: false,
|
||||
properties: []
|
||||
}
|
||||
];
|
||||
let subscriptions: azureResource.AzureResourceSubscription[] = [
|
||||
{
|
||||
name: 'subscription',
|
||||
id: '2'
|
||||
}
|
||||
];
|
||||
let groups: azureResource.AzureResourceResourceGroup[] = [
|
||||
{
|
||||
name: 'group',
|
||||
id: '3'
|
||||
}
|
||||
];
|
||||
let workspaces: Workspace[] = [
|
||||
{
|
||||
name: 'workspace',
|
||||
id: '4'
|
||||
}
|
||||
];
|
||||
let models: WorkspaceModel[] = [
|
||||
{
|
||||
id: '5',
|
||||
name: 'model'
|
||||
}
|
||||
];
|
||||
let localModels: ImportedModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'model',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [
|
||||
'db', 'db1'
|
||||
] });
|
||||
});
|
||||
view.on(ListTableNamesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [
|
||||
'tb', 'tb1'
|
||||
] });
|
||||
});
|
||||
view.on(ListAccountsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
|
||||
});
|
||||
view.on(ListSubscriptionsEventName, () => {
|
||||
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
|
||||
});
|
||||
view.on(ListGroupsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
|
||||
});
|
||||
view.on(ListWorkspacesEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
|
||||
});
|
||||
view.on(ListAzureModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
|
||||
});
|
||||
|
||||
if (view.modelBrowsePage) {
|
||||
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
|
||||
}
|
||||
await view.refresh();
|
||||
should.notEqual(view.azureModelsComponent?.data ,undefined);
|
||||
should.notEqual(view.localModelsComponent?.data, undefined);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,46 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as should from 'should';
|
||||
import 'mocha';
|
||||
import { createContext } from './utils';
|
||||
import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog';
|
||||
import { ListModelsEventName } from '../../../views/models/modelViewBase';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
import { ViewBase } from '../../../views/viewBase';
|
||||
|
||||
describe('Registered Models Dialog', () => {
|
||||
it('Should create view components successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
|
||||
should.notEqual(view.dialogView, undefined);
|
||||
should.notEqual(view.currentLanguagesTab, undefined);
|
||||
});
|
||||
|
||||
it('Should load data successfully ', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
|
||||
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
|
||||
view.open();
|
||||
let models: ImportedModel[] = [
|
||||
{
|
||||
id: 1,
|
||||
modelName: 'model',
|
||||
table: {
|
||||
databaseName: 'db',
|
||||
tableName: 'tb',
|
||||
schema: 'dbo'
|
||||
}
|
||||
}
|
||||
];
|
||||
view.on(ListModelsEventName, () => {
|
||||
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: models });
|
||||
});
|
||||
await view.refresh();
|
||||
});
|
||||
});
|
||||
50
extensions/machine-learning/src/test/views/models/utils.ts
Normal file
50
extensions/machine-learning/src/test/views/models/utils.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { createViewContext } from '../utils';
|
||||
import { ModelViewBase } from '../../../views/models/modelViewBase';
|
||||
import { AzureModelRegistryService } from '../../../modelManagement/azureModelRegistryService';
|
||||
import { DeployedModelService } from '../../../modelManagement/deployedModelService';
|
||||
import { PredictService } from '../../../prediction/predictService';
|
||||
|
||||
export interface TestContext {
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
view: azdata.ModelView;
|
||||
onClick: vscode.EventEmitter<any>;
|
||||
azureModelService: TypeMoq.IMock<AzureModelRegistryService>;
|
||||
deployModelService: TypeMoq.IMock<DeployedModelService>;
|
||||
predictService: TypeMoq.IMock<PredictService>;
|
||||
}
|
||||
|
||||
export class ParentDialog extends ModelViewBase {
|
||||
public refresh(): Promise<void> {
|
||||
return Promise.resolve();
|
||||
}
|
||||
public reset(): Promise<void> {
|
||||
return Promise.resolve();
|
||||
}
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper) {
|
||||
super(apiWrapper, '');
|
||||
}
|
||||
}
|
||||
|
||||
export function createContext(): TestContext {
|
||||
|
||||
let viewTestContext = createViewContext();
|
||||
|
||||
return {
|
||||
apiWrapper: viewTestContext.apiWrapper,
|
||||
view: viewTestContext.view,
|
||||
onClick: viewTestContext.onClick,
|
||||
azureModelService: TypeMoq.Mock.ofType(AzureModelRegistryService),
|
||||
deployModelService: TypeMoq.Mock.ofType(DeployedModelService),
|
||||
predictService: TypeMoq.Mock.ofType(PredictService)
|
||||
};
|
||||
}
|
||||
333
extensions/machine-learning/src/test/views/utils.ts
Normal file
333
extensions/machine-learning/src/test/views/utils.ts
Normal file
@@ -0,0 +1,333 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export interface ViewTestContext {
|
||||
apiWrapper: TypeMoq.IMock<ApiWrapper>;
|
||||
view: azdata.ModelView;
|
||||
onClick: vscode.EventEmitter<any>;
|
||||
}
|
||||
|
||||
export function createViewContext(): ViewTestContext {
|
||||
let onClick: vscode.EventEmitter<any> = new vscode.EventEmitter<any>();
|
||||
|
||||
let apiWrapper = TypeMoq.Mock.ofType(ApiWrapper);
|
||||
let componentBase: azdata.Component = {
|
||||
id: '',
|
||||
updateProperties: () => Promise.resolve(),
|
||||
updateProperty: () => Promise.resolve(),
|
||||
updateCssStyles: undefined!,
|
||||
onValidityChanged: undefined!,
|
||||
valid: true,
|
||||
validate: undefined!,
|
||||
focus: undefined!
|
||||
};
|
||||
let button: azdata.ButtonComponent = Object.assign({}, componentBase, {
|
||||
onDidClick: onClick.event
|
||||
});
|
||||
let link: azdata.HyperlinkComponent = Object.assign({}, componentBase, {
|
||||
onDidClick: onClick.event,
|
||||
label: '',
|
||||
url: ''
|
||||
});
|
||||
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
|
||||
checked: true,
|
||||
onDidClick: onClick.event
|
||||
});
|
||||
let checkbox: azdata.CheckBoxComponent = Object.assign({}, componentBase, {
|
||||
checked: true,
|
||||
onChanged: onClick.event
|
||||
});
|
||||
let container = {
|
||||
clearItems: () => { },
|
||||
addItems: () => { },
|
||||
addItem: () => { },
|
||||
removeItem: () => true,
|
||||
insertItem: () => { },
|
||||
items: [],
|
||||
setLayout: () => { }
|
||||
};
|
||||
let form: azdata.FormContainer = Object.assign({}, componentBase, container, {
|
||||
});
|
||||
let flex: azdata.FlexContainer = Object.assign({}, componentBase, container, {
|
||||
});
|
||||
let div: azdata.DivContainer = Object.assign({}, componentBase, container, {
|
||||
onDidClick: onClick.event
|
||||
});
|
||||
|
||||
let buttonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
|
||||
component: () => button,
|
||||
withProperties: () => buttonBuilder,
|
||||
withValidation: () => buttonBuilder
|
||||
};
|
||||
let hyperLinkBuilder: azdata.ComponentBuilder<azdata.HyperlinkComponent> = {
|
||||
component: () => link,
|
||||
withProperties: () => hyperLinkBuilder,
|
||||
withValidation: () => hyperLinkBuilder
|
||||
};
|
||||
let radioButtonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
|
||||
component: () => radioButton,
|
||||
withProperties: () => radioButtonBuilder,
|
||||
withValidation: () => radioButtonBuilder
|
||||
};
|
||||
let checkBoxBuilder: azdata.ComponentBuilder<azdata.CheckBoxComponent> = {
|
||||
component: () => checkbox,
|
||||
withProperties: () => checkBoxBuilder,
|
||||
withValidation: () => checkBoxBuilder
|
||||
};
|
||||
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
|
||||
onTextChanged: onClick.event!,
|
||||
onEnterKeyPressed: undefined!,
|
||||
value: ''
|
||||
});
|
||||
let image: () => azdata.ImageComponent = () => Object.assign({}, componentBase, {
|
||||
|
||||
});
|
||||
let dropdown: () => azdata.DropDownComponent = () => Object.assign({}, componentBase, {
|
||||
onValueChanged: onClick.event,
|
||||
value: {
|
||||
name: '',
|
||||
displayName: ''
|
||||
},
|
||||
values: []
|
||||
});
|
||||
let declarativeTable: () => azdata.DeclarativeTableComponent = () => Object.assign({}, componentBase, {
|
||||
onDataChanged: undefined!,
|
||||
data: [],
|
||||
columns: []
|
||||
});
|
||||
|
||||
let loadingComponent: () => azdata.LoadingComponent = () => Object.assign({}, componentBase, {
|
||||
loading: false,
|
||||
component: undefined!
|
||||
});
|
||||
|
||||
let card: () => azdata.CardComponent = () => Object.assign({}, componentBase, {
|
||||
label: '',
|
||||
onDidActionClick: new vscode.EventEmitter<azdata.ActionDescriptor>().event,
|
||||
onCardSelectedChanged: onClick.event
|
||||
});
|
||||
|
||||
let declarativeTableBuilder: azdata.ComponentBuilder<azdata.DeclarativeTableComponent> = {
|
||||
component: () => declarativeTable(),
|
||||
withProperties: () => declarativeTableBuilder,
|
||||
withValidation: () => declarativeTableBuilder
|
||||
};
|
||||
|
||||
let loadingBuilder: azdata.LoadingComponentBuilder = {
|
||||
component: () => loadingComponent(),
|
||||
withProperties: () => loadingBuilder,
|
||||
withValidation: () => loadingBuilder,
|
||||
withItem: () => loadingBuilder
|
||||
};
|
||||
|
||||
let formBuilder: azdata.FormBuilder = Object.assign({}, {
|
||||
component: () => form,
|
||||
addFormItem: () => { },
|
||||
insertFormItem: () => { },
|
||||
removeFormItem: () => true,
|
||||
addFormItems: () => { },
|
||||
withFormItems: () => formBuilder,
|
||||
withProperties: () => formBuilder,
|
||||
withValidation: () => formBuilder,
|
||||
withItems: () => formBuilder,
|
||||
withLayout: () => formBuilder
|
||||
});
|
||||
|
||||
let flexBuilder: azdata.FlexBuilder = Object.assign({}, {
|
||||
component: () => flex,
|
||||
withProperties: () => flexBuilder,
|
||||
withValidation: () => flexBuilder,
|
||||
withItems: () => flexBuilder,
|
||||
withLayout: () => flexBuilder
|
||||
});
|
||||
let divBuilder: azdata.DivBuilder = Object.assign({}, {
|
||||
component: () => div,
|
||||
withProperties: () => divBuilder,
|
||||
withValidation: () => divBuilder,
|
||||
withItems: () => divBuilder,
|
||||
withLayout: () => divBuilder
|
||||
});
|
||||
|
||||
let inputBoxBuilder: azdata.ComponentBuilder<azdata.InputBoxComponent> = {
|
||||
component: () => {
|
||||
let r = inputBox();
|
||||
return r;
|
||||
},
|
||||
withProperties: () => inputBoxBuilder,
|
||||
withValidation: () => inputBoxBuilder
|
||||
};
|
||||
let cardBuilder: azdata.ComponentBuilder<azdata.CardComponent> = {
|
||||
component: () => {
|
||||
let r = card();
|
||||
return r;
|
||||
},
|
||||
withProperties: () => cardBuilder,
|
||||
withValidation: () => cardBuilder
|
||||
};
|
||||
|
||||
let imageBuilder: azdata.ComponentBuilder<azdata.ImageComponent> = {
|
||||
component: () => {
|
||||
let r = image();
|
||||
return r;
|
||||
},
|
||||
withProperties: () => imageBuilder,
|
||||
withValidation: () => imageBuilder
|
||||
};
|
||||
let dropdownBuilder: azdata.ComponentBuilder<azdata.DropDownComponent> = {
|
||||
component: () => {
|
||||
let r = dropdown();
|
||||
return r;
|
||||
},
|
||||
withProperties: () => dropdownBuilder,
|
||||
withValidation: () => dropdownBuilder
|
||||
};
|
||||
|
||||
let view: azdata.ModelView = {
|
||||
onClosed: undefined!,
|
||||
connection: undefined!,
|
||||
serverInfo: undefined!,
|
||||
valid: true,
|
||||
onValidityChanged: undefined!,
|
||||
validate: undefined!,
|
||||
initializeModel: () => { return Promise.resolve(); },
|
||||
modelBuilder: {
|
||||
radioCardGroup: undefined!,
|
||||
navContainer: undefined!,
|
||||
divContainer: () => divBuilder,
|
||||
flexContainer: () => flexBuilder,
|
||||
splitViewContainer: undefined!,
|
||||
dom: undefined!,
|
||||
card: () => cardBuilder,
|
||||
inputBox: () => inputBoxBuilder,
|
||||
checkBox: () => checkBoxBuilder!,
|
||||
radioButton: () => radioButtonBuilder,
|
||||
webView: undefined!,
|
||||
editor: undefined!,
|
||||
diffeditor: undefined!,
|
||||
text: () => inputBoxBuilder,
|
||||
image: () => imageBuilder,
|
||||
button: () => buttonBuilder,
|
||||
dropDown: () => dropdownBuilder,
|
||||
tree: undefined!,
|
||||
listBox: undefined!,
|
||||
table: undefined!,
|
||||
declarativeTable: () => declarativeTableBuilder,
|
||||
dashboardWidget: undefined!,
|
||||
dashboardWebview: undefined!,
|
||||
formContainer: () => formBuilder,
|
||||
groupContainer: undefined!,
|
||||
toolbarContainer: undefined!,
|
||||
loadingComponent: () => loadingBuilder,
|
||||
fileBrowserTree: undefined!,
|
||||
hyperlink: () => hyperLinkBuilder,
|
||||
tabbedPanel: undefined!,
|
||||
separator: undefined!,
|
||||
propertiesContainer: undefined!
|
||||
}
|
||||
};
|
||||
let tab: azdata.window.DialogTab = {
|
||||
title: '',
|
||||
content: '',
|
||||
registerContent: async (handler) => {
|
||||
try {
|
||||
await handler(view);
|
||||
} catch (err) {
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
onValidityChanged: undefined!,
|
||||
valid: true,
|
||||
modelView: undefined!
|
||||
};
|
||||
|
||||
let dialogButton: azdata.window.Button = {
|
||||
label: '',
|
||||
enabled: true,
|
||||
hidden: false,
|
||||
onClick: onClick.event,
|
||||
|
||||
};
|
||||
let dialogMessage: azdata.window.DialogMessage = {
|
||||
text: '',
|
||||
};
|
||||
let dialog: azdata.window.Dialog = {
|
||||
title: '',
|
||||
isWide: false,
|
||||
content: [],
|
||||
okButton: dialogButton,
|
||||
cancelButton: dialogButton,
|
||||
customButtons: [],
|
||||
message: dialogMessage,
|
||||
registerCloseValidator: () => { },
|
||||
registerOperation: () => { },
|
||||
onValidityChanged: new vscode.EventEmitter<boolean>().event,
|
||||
registerContent: () => { },
|
||||
modelView: undefined!,
|
||||
valid: true
|
||||
};
|
||||
let wizard: azdata.window.Wizard = {
|
||||
title: '',
|
||||
pages: [],
|
||||
currentPage: 0,
|
||||
doneButton: dialogButton,
|
||||
cancelButton: dialogButton,
|
||||
generateScriptButton: dialogButton,
|
||||
nextButton: dialogButton,
|
||||
backButton: dialogButton,
|
||||
customButtons: [],
|
||||
displayPageTitles: true,
|
||||
onPageChanged: onClick.event,
|
||||
addPage: () => { return Promise.resolve(); },
|
||||
removePage: () => { return Promise.resolve(); },
|
||||
setCurrentPage: () => { return Promise.resolve(); },
|
||||
open: () => { return Promise.resolve(); },
|
||||
close: () => { return Promise.resolve(); },
|
||||
registerNavigationValidator: () => { },
|
||||
message: dialogMessage,
|
||||
registerOperation: () => { }
|
||||
};
|
||||
let wizardPage: azdata.window.WizardPage = {
|
||||
title: '',
|
||||
content: '',
|
||||
customButtons: [],
|
||||
enabled: true,
|
||||
description: '',
|
||||
onValidityChanged: onClick.event,
|
||||
registerContent: async (handler) => {
|
||||
try {
|
||||
await handler(view);
|
||||
} catch (err) {
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
modelView: undefined!,
|
||||
valid: true
|
||||
};
|
||||
apiWrapper.setup(x => x.createButton(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => dialogButton);
|
||||
apiWrapper.setup(x => x.createTab(TypeMoq.It.isAny())).returns(() => tab);
|
||||
apiWrapper.setup(x => x.createWizard(TypeMoq.It.isAny())).returns(() => wizard);
|
||||
apiWrapper.setup(x => x.createWizardPage(TypeMoq.It.isAny())).returns(() => wizardPage);
|
||||
apiWrapper.setup(x => x.createModelViewDialog(TypeMoq.It.isAny())).returns(() => dialog);
|
||||
apiWrapper.setup(x => x.openDialog(TypeMoq.It.isAny())).returns(() => { });
|
||||
apiWrapper.setup(x => x.registerWidget(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(async (id, handler) => {
|
||||
if (id) {
|
||||
return await handler(view);
|
||||
} else {
|
||||
Promise.reject();
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
apiWrapper: apiWrapper,
|
||||
view: view,
|
||||
onClick: onClick,
|
||||
};
|
||||
}
|
||||
|
||||
22
extensions/machine-learning/src/types.ts
Normal file
22
extensions/machine-learning/src/types.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
const _typeof = {
|
||||
undefined: 'undefined'
|
||||
};
|
||||
|
||||
/**
|
||||
* @returns whether the provided parameter is undefined or null.
|
||||
*/
|
||||
export function isUndefinedOrNull(obj: any): boolean {
|
||||
return isUndefined(obj) || obj === null;
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns whether the provided parameter is undefined.
|
||||
*/
|
||||
export function isUndefined(obj: any): boolean {
|
||||
return typeof (obj) === _typeof.undefined;
|
||||
}
|
||||
28
extensions/machine-learning/src/typings/azure-resource.d.ts
vendored
Normal file
28
extensions/machine-learning/src/typings/azure-resource.d.ts
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { Account } from 'azdata';
|
||||
import * as msRest from '@azure/ms-rest-js';
|
||||
|
||||
export namespace azureResource {
|
||||
|
||||
export interface AzureResource {
|
||||
name: string;
|
||||
id: string;
|
||||
}
|
||||
|
||||
export interface AzureResourceSubscription extends AzureResource {
|
||||
}
|
||||
|
||||
export interface AzureResourceResourceGroup extends AzureResource {
|
||||
}
|
||||
|
||||
export interface IAzureResourceService<T extends AzureResource> {
|
||||
getResources(subscription: AzureResourceSubscription, credential: msRest.ServiceClientCredentials): Promise<T[]>;
|
||||
}
|
||||
|
||||
export type GetSubscriptionsResult = { subscriptions: AzureResourceSubscription[], errors: Error[] };
|
||||
export type GetResourceGroupsResult = { resourceGroups: AzureResourceResourceGroup[], errors: Error[] };
|
||||
}
|
||||
106
extensions/machine-learning/src/typings/notebookServices.d.ts
vendored
Normal file
106
extensions/machine-learning/src/typings/notebookServices.d.ts
vendored
Normal file
@@ -0,0 +1,106 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
/**
|
||||
* The API provided by this extension.
|
||||
*
|
||||
* @export
|
||||
*/
|
||||
export interface IExtensionApi {
|
||||
getJupyterController(): IJupyterController;
|
||||
registerPackageManager(providerId: string, packageManagerProvider: IPackageManageProvider): void
|
||||
getPackageManagers(): Map<string, IPackageManageProvider>
|
||||
}
|
||||
|
||||
export interface IJupyterController {
|
||||
jupyterInstallation: IJupyterServerInstallation;
|
||||
}
|
||||
|
||||
export interface IJupyterServerInstallation {
|
||||
installPipPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise<void>;
|
||||
uninstallPipPackages(packages: IPackageDetails[]): Promise<void>;
|
||||
installCondaPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise<void>;
|
||||
uninstallCondaPackages(packages: IPackageDetails[]): Promise<void>;
|
||||
getInstalledPipPackages(): Promise<IPackageDetails[]>;
|
||||
pythonExecutable: string;
|
||||
pythonInstallationPath: string;
|
||||
executeBufferedCommand(command: string): Promise<string>;
|
||||
executeStreamedCommand(command: string): Promise<void>;
|
||||
installPythonPackage(backgroundOperation: azdata.BackgroundOperation, usingExistingPython: boolean, pythonInstallationPath: string, outputChannel: vscode.OutputChannel): Promise<void>;
|
||||
}
|
||||
|
||||
|
||||
export interface IPackageDetails {
|
||||
name: string;
|
||||
version: string;
|
||||
}
|
||||
|
||||
export interface IPackageTarget {
|
||||
location: string;
|
||||
packageType: string;
|
||||
}
|
||||
|
||||
export interface IPackageOverview {
|
||||
name: string;
|
||||
versions: string[];
|
||||
summary: string;
|
||||
}
|
||||
|
||||
export interface IPackageLocation {
|
||||
name: string;
|
||||
displayName: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Package manage provider interface
|
||||
*/
|
||||
export interface IPackageManageProvider {
|
||||
/**
|
||||
* Provider id
|
||||
*/
|
||||
providerId: string;
|
||||
|
||||
/**
|
||||
* package target
|
||||
*/
|
||||
packageTarget: IPackageTarget;
|
||||
|
||||
/**
|
||||
* Returns list of installed packages
|
||||
*/
|
||||
listPackages(location?: string): Promise<IPackageDetails[]>;
|
||||
|
||||
/**
|
||||
* Installs give packages
|
||||
* @param package Packages to install
|
||||
* @param useMinVersion if true, minimal version will be used
|
||||
*/
|
||||
installPackages(package: IPackageDetails[], useMinVersion: boolean, location?: string): Promise<void>;
|
||||
|
||||
/**
|
||||
* Uninstalls given packages
|
||||
* @param package package to uninstall
|
||||
*/
|
||||
uninstallPackages(package: IPackageDetails[], location?: string): Promise<void>;
|
||||
|
||||
/**
|
||||
* Returns true if the provider can be used in current context
|
||||
*/
|
||||
canUseProvider(): Promise<boolean>;
|
||||
|
||||
/**
|
||||
* Returns location title
|
||||
*/
|
||||
getLocations(): Promise<IPackageLocation[]>;
|
||||
|
||||
/**
|
||||
* Returns Package Overview
|
||||
* @param packageName package name
|
||||
*/
|
||||
getPackageOverview(packageName: string): Promise<IPackageOverview>;
|
||||
}
|
||||
9
extensions/machine-learning/src/typings/ref.d.ts
vendored
Normal file
9
extensions/machine-learning/src/typings/ref.d.ts
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
/// <reference path='../../../../src/vs/vscode.d.ts'/>
|
||||
/// <reference path='../../../../src/sql/azdata.d.ts'/>
|
||||
/// <reference path='../../../../src/sql/azdata.proposed.d.ts'/>
|
||||
/// <reference types='@types/node'/>
|
||||
54
extensions/machine-learning/src/views/controllerBase.ts
Normal file
54
extensions/machine-learning/src/views/controllerBase.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
import { ViewBase, LocalPathsEventName } from './viewBase';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
|
||||
/**
|
||||
* Base classes for UI controllers
|
||||
*/
|
||||
export abstract class ControllerBase {
|
||||
|
||||
/**
|
||||
* creates new instance
|
||||
*/
|
||||
constructor(protected _apiWrapper: ApiWrapper) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes an action and sends back callback event to the view
|
||||
*/
|
||||
public async executeAction<T extends ViewBase>(dialog: T, eventName: string, func: (...args: any[]) => Promise<any>, ...args: any[]): Promise<void> {
|
||||
const callbackEvent = ViewBase.getCallbackEventName(eventName);
|
||||
try {
|
||||
let result = await func(...args);
|
||||
dialog.sendCallbackRequest(callbackEvent, { data: result });
|
||||
|
||||
} catch (error) {
|
||||
dialog.sendCallbackRequest(callbackEvent, { error: error });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Register common events for views
|
||||
* @param view view
|
||||
*/
|
||||
public registerEvents(view: ViewBase): void {
|
||||
view.on(LocalPathsEventName, async (args) => {
|
||||
await this.executeAction(view, LocalPathsEventName, this.getLocalPaths, this._apiWrapper, args);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns local file path picked by the user
|
||||
* @param apiWrapper apiWrapper
|
||||
*/
|
||||
public async getLocalPaths(apiWrapper: ApiWrapper, options: vscode.OpenDialogOptions): Promise<string[]> {
|
||||
let result = await apiWrapper.showOpenDialog(options);
|
||||
return result ? result?.map(x => x.fsPath) : [];
|
||||
}
|
||||
}
|
||||
41
extensions/machine-learning/src/views/dialogView.ts
Normal file
41
extensions/machine-learning/src/views/dialogView.ts
Normal file
@@ -0,0 +1,41 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
import * as azdata from 'azdata';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { MainViewBase } from './mainViewBase';
|
||||
import { IPageView } from './interfaces';
|
||||
|
||||
/**
|
||||
* Dialog view to create and manage a dialog
|
||||
*/
|
||||
export class DialogView extends MainViewBase {
|
||||
|
||||
private _dialog: azdata.window.Dialog | undefined;
|
||||
|
||||
/**
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
|
||||
private createDialogPage(title: string, componentView: IPageView): azdata.window.DialogTab {
|
||||
let viewPanel = this._apiWrapper.createTab(title);
|
||||
this.addPage(componentView);
|
||||
this.registerContent(viewPanel, componentView);
|
||||
return viewPanel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new dialog
|
||||
* @param title title
|
||||
* @param pages pages
|
||||
*/
|
||||
public createDialog(title: string, pages: IPageView[]): azdata.window.Dialog {
|
||||
this._dialog = this._apiWrapper.createModelViewDialog(title);
|
||||
this._dialog.content = pages.map(x => this.createDialogPage(x.title || '', x));
|
||||
return this._dialog;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as constants from '../../common/constants';
|
||||
import { LanguageViewBase, LanguageUpdateModel } from './languageViewBase';
|
||||
import { LanguageContentView } from './languageContentView';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export class AddEditLanguageTab extends LanguageViewBase {
|
||||
private _dialogTab: azdata.window.DialogTab;
|
||||
public languageName: azdata.TextComponent | undefined;
|
||||
private _editMode: boolean = false;
|
||||
public saveButton: azdata.ButtonComponent | undefined;
|
||||
public languageView: LanguageContentView | undefined;
|
||||
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
parent: LanguageViewBase,
|
||||
private _languageUpdateModel: LanguageUpdateModel) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._editMode = !this._languageUpdateModel.newLang;
|
||||
this._dialogTab = apiWrapper.createTab(constants.extLangNewLanguageTabTitle);
|
||||
this._dialogTab.registerContent(async view => {
|
||||
let language = this._languageUpdateModel.language;
|
||||
let content = this._languageUpdateModel.content;
|
||||
this.languageName = view.modelBuilder.inputBox().withProperties({
|
||||
value: language.name,
|
||||
width: '150px',
|
||||
enabled: !this._editMode
|
||||
}).withValidation(component => component.value !== '').component();
|
||||
|
||||
let formBuilder = view.modelBuilder.formContainer();
|
||||
formBuilder.addFormItem({
|
||||
component: this.languageName,
|
||||
title: constants.extLangLanguageName,
|
||||
required: true
|
||||
});
|
||||
|
||||
this.languageView = new LanguageContentView(this._apiWrapper, this, view.modelBuilder, formBuilder, content);
|
||||
|
||||
if (!this._editMode) {
|
||||
this.saveButton = view.modelBuilder.button().withProperties({
|
||||
label: constants.extLangInstallButtonText,
|
||||
width: '100px'
|
||||
}).component();
|
||||
this.saveButton.onDidClick(async () => {
|
||||
try {
|
||||
await this.updateLanguage(this.updatedData);
|
||||
} catch (err) {
|
||||
this.showErrorMessage(constants.extLangInstallFailedError, err);
|
||||
}
|
||||
});
|
||||
|
||||
formBuilder.addFormItem({
|
||||
component: this.saveButton,
|
||||
title: ''
|
||||
});
|
||||
}
|
||||
|
||||
await view.initializeModel(formBuilder.component());
|
||||
await this.reset();
|
||||
});
|
||||
}
|
||||
|
||||
public get updatedData(): LanguageUpdateModel {
|
||||
return {
|
||||
language: {
|
||||
name: this.languageName?.value || '',
|
||||
contents: this._languageUpdateModel.language.contents
|
||||
},
|
||||
content: this.languageView?.updatedContent || this._languageUpdateModel.content,
|
||||
newLang: this._languageUpdateModel.newLang
|
||||
};
|
||||
}
|
||||
|
||||
public get tab(): azdata.window.DialogTab {
|
||||
return this._dialogTab;
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
if (this.languageName) {
|
||||
this.languageName.value = this._languageUpdateModel.language.name;
|
||||
}
|
||||
this.languageView?.reset();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import * as constants from '../../common/constants';
|
||||
import { LanguageViewBase } from './languageViewBase';
|
||||
import { LanguagesTable } from './languagesTable';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export class CurrentLanguagesTab extends LanguageViewBase {
|
||||
|
||||
private _installedLangsTab: azdata.window.DialogTab;
|
||||
|
||||
private _locationComponent: azdata.TextComponent | undefined;
|
||||
private _installLanguagesTable: azdata.DeclarativeTableComponent | undefined;
|
||||
private _languageTable: LanguagesTable | undefined;
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: LanguageViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._installedLangsTab = this._apiWrapper.createTab(constants.extLangInstallTabTitle);
|
||||
|
||||
this._installedLangsTab.registerContent(async view => {
|
||||
|
||||
// TODO: only supporting single location for now. We should add a drop down for multi locations mode
|
||||
//
|
||||
let locationTitle = await this.getServerTitle();
|
||||
this._locationComponent = view.modelBuilder.text().withProperties({
|
||||
value: locationTitle
|
||||
}).component();
|
||||
|
||||
this._languageTable = new LanguagesTable(apiWrapper, view.modelBuilder, this);
|
||||
this._installLanguagesTable = this._languageTable.table;
|
||||
|
||||
let formModel = view.modelBuilder.formContainer()
|
||||
.withFormItems([{
|
||||
component: this._locationComponent,
|
||||
title: constants.extLangTarget
|
||||
}, {
|
||||
component: this._installLanguagesTable,
|
||||
title: ''
|
||||
}]).component();
|
||||
|
||||
this._loader = view.modelBuilder.loadingComponent()
|
||||
.withItem(formModel)
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
|
||||
await view.initializeModel(this._loader);
|
||||
await this.reset();
|
||||
});
|
||||
}
|
||||
|
||||
public get tab(): azdata.window.DialogTab {
|
||||
return this._installedLangsTab;
|
||||
}
|
||||
|
||||
private async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
}
|
||||
}
|
||||
|
||||
private async onLoaded(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: false });
|
||||
}
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
await this.onLoading();
|
||||
|
||||
try {
|
||||
await this._languageTable?.reset();
|
||||
} catch (err) {
|
||||
this.showErrorMessage(constants.getErrorMessage(err));
|
||||
} finally {
|
||||
await this.onLoaded();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export class FileBrowserDialog {
|
||||
|
||||
private _selectedPathTextBox: azdata.InputBoxComponent | undefined;
|
||||
private _fileBrowserDialog: azdata.window.Dialog | undefined;
|
||||
private _fileBrowserTree: azdata.FileBrowserTreeComponent | undefined;
|
||||
|
||||
private _onPathSelected: vscode.EventEmitter<string> = new vscode.EventEmitter<string>();
|
||||
public readonly onPathSelected: vscode.Event<string> = this._onPathSelected.event;
|
||||
|
||||
constructor(private _apiWrapper: ApiWrapper, private ownerUri: string) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens a dialog to browse server files and folders.
|
||||
*/
|
||||
public showDialog(): void {
|
||||
let fileBrowserTitle = '';
|
||||
this._fileBrowserDialog = this._apiWrapper.createModelViewDialog(fileBrowserTitle);
|
||||
let fileBrowserTab = this._apiWrapper.createTab(constants.extLangFileBrowserTabTitle);
|
||||
this._fileBrowserDialog.content = [fileBrowserTab];
|
||||
fileBrowserTab.registerContent(async (view) => {
|
||||
this._fileBrowserTree = view.modelBuilder.fileBrowserTree()
|
||||
.withProperties({ ownerUri: this.ownerUri, width: 420, height: 700 })
|
||||
.component();
|
||||
this._selectedPathTextBox = view.modelBuilder.inputBox()
|
||||
.withProperties({ inputType: 'text' })
|
||||
.component();
|
||||
this._fileBrowserTree.onDidChange((args) => {
|
||||
if (this._selectedPathTextBox) {
|
||||
this._selectedPathTextBox.value = args.fullPath;
|
||||
}
|
||||
});
|
||||
|
||||
let fileBrowserContainer = view.modelBuilder.formContainer()
|
||||
.withFormItems([{
|
||||
component: this._fileBrowserTree,
|
||||
title: ''
|
||||
}, {
|
||||
component: this._selectedPathTextBox,
|
||||
title: constants.extLangSelectedPath
|
||||
}
|
||||
]).component();
|
||||
view.initializeModel(fileBrowserContainer);
|
||||
});
|
||||
this._fileBrowserDialog.okButton.onClick(() => {
|
||||
if (this._selectedPathTextBox) {
|
||||
let selectedPath = this._selectedPathTextBox.value || '';
|
||||
this._onPathSelected.fire(selectedPath);
|
||||
}
|
||||
});
|
||||
|
||||
this._fileBrowserDialog.cancelButton.onClick(() => {
|
||||
this._onPathSelected.fire('');
|
||||
});
|
||||
this._fileBrowserDialog.okButton.label = constants.extLangOkButtonText;
|
||||
this._fileBrowserDialog.cancelButton.label = constants.extLangCancelButtonText;
|
||||
this._apiWrapper.openDialog(this._fileBrowserDialog);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as mssql from '../../../../mssql';
|
||||
import { LanguageViewBase } from './languageViewBase';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export class LanguageContentView extends LanguageViewBase {
|
||||
|
||||
private _serverPath: azdata.RadioButtonComponent;
|
||||
private _localPath: azdata.RadioButtonComponent;
|
||||
public extensionFile: azdata.TextComponent;
|
||||
public extensionFileName: azdata.TextComponent;
|
||||
public envVariables: azdata.TextComponent;
|
||||
public parameters: azdata.TextComponent;
|
||||
private _isLocalPath: boolean = true;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
parent: LanguageViewBase,
|
||||
private _modelBuilder: azdata.ModelBuilder,
|
||||
private _formBuilder: azdata.FormBuilder,
|
||||
private _languageContent: mssql.ExternalLanguageContent | undefined,
|
||||
) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._localPath = this._modelBuilder.radioButton()
|
||||
.withProperties({
|
||||
value: 'local',
|
||||
name: 'extensionLocation',
|
||||
label: constants.extLangLocal,
|
||||
checked: true
|
||||
}).component();
|
||||
|
||||
this._serverPath = this._modelBuilder.radioButton()
|
||||
.withProperties({
|
||||
value: 'server',
|
||||
name: 'extensionLocation',
|
||||
label: this.getServerTitle(),
|
||||
}).component();
|
||||
|
||||
this._localPath.onDidClick(() => {
|
||||
this._isLocalPath = true;
|
||||
});
|
||||
this._serverPath.onDidClick(() => {
|
||||
this._isLocalPath = false;
|
||||
});
|
||||
|
||||
|
||||
let flexRadioButtonsModel = this._modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between'
|
||||
//width: parent.componentMaxLength
|
||||
}).withItems([
|
||||
this._localPath, this._serverPath]
|
||||
).component();
|
||||
|
||||
this.extensionFile = this._modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: parent.componentMaxLength - parent.browseButtonMaxLength - parent.spaceBetweenComponentsLength
|
||||
}).component();
|
||||
let fileBrowser = this._modelBuilder.button().withProperties({
|
||||
label: '...',
|
||||
width: parent.browseButtonMaxLength,
|
||||
CSSStyles: {
|
||||
'text-align': 'end'
|
||||
}
|
||||
}).component();
|
||||
|
||||
let flexFilePathModel = this._modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between'
|
||||
}).withItems([
|
||||
this.extensionFile, fileBrowser]
|
||||
).component();
|
||||
this.filePathSelected(args => {
|
||||
this.extensionFile.value = args.filePath;
|
||||
});
|
||||
fileBrowser.onDidClick(async () => {
|
||||
this.onOpenFileBrowser({ filePath: '', target: this._isLocalPath ? constants.localhost : this.connectionUrl });
|
||||
});
|
||||
|
||||
this.extensionFileName = this._modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: parent.componentMaxLength
|
||||
}).component();
|
||||
|
||||
this.envVariables = this._modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: parent.componentMaxLength
|
||||
}).component();
|
||||
this.parameters = this._modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: parent.componentMaxLength
|
||||
}).component();
|
||||
|
||||
this.load();
|
||||
|
||||
this._formBuilder.addFormItems([{
|
||||
component: flexRadioButtonsModel,
|
||||
title: constants.extLangExtensionFileLocation
|
||||
}, {
|
||||
component: flexFilePathModel,
|
||||
title: constants.extLangExtensionFilePath,
|
||||
required: true
|
||||
}, {
|
||||
component: this.extensionFileName,
|
||||
title: constants.extLangExtensionFileName,
|
||||
required: true
|
||||
}, {
|
||||
component: this.envVariables,
|
||||
title: constants.extLangEnvVariables
|
||||
}, {
|
||||
component: this.parameters,
|
||||
title: constants.extLangParameters
|
||||
}]);
|
||||
}
|
||||
|
||||
private load() {
|
||||
if (this._languageContent) {
|
||||
this._isLocalPath = this._languageContent.isLocalFile;
|
||||
this._localPath.checked = this._isLocalPath;
|
||||
this._serverPath.checked = !this._isLocalPath;
|
||||
this.extensionFile.value = this._languageContent.pathToExtension;
|
||||
this.extensionFileName.value = this._languageContent.extensionFileName;
|
||||
this.envVariables.value = this._languageContent.environmentVariables;
|
||||
this.parameters.value = this._languageContent.parameters;
|
||||
}
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
this._isLocalPath = true;
|
||||
this._localPath.checked = this._isLocalPath;
|
||||
this._serverPath.checked = !this._isLocalPath;
|
||||
this.load();
|
||||
}
|
||||
|
||||
public get updatedContent(): mssql.ExternalLanguageContent {
|
||||
return {
|
||||
pathToExtension: this.extensionFile.value || '',
|
||||
extensionFileName: this.extensionFileName.value || '',
|
||||
parameters: this.parameters.value || '',
|
||||
environmentVariables: this.envVariables.value || '',
|
||||
isLocalFile: this._isLocalPath || false,
|
||||
platform: this._languageContent?.platform
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as mssql from '../../../../mssql';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { LanguageService } from '../../externalLanguage/languageService';
|
||||
import { LanguagesDialog } from './languagesDialog';
|
||||
import { LanguageEditDialog } from './languageEditDialog';
|
||||
import { FileBrowserDialog } from './fileBrowserDialog';
|
||||
import { LanguageViewBase, LanguageUpdateModel } from './languageViewBase';
|
||||
import * as constants from '../../common/constants';
|
||||
|
||||
export class LanguageController {
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor(
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _root: string,
|
||||
private _service: LanguageService) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the manage language dialog and connects events to the model
|
||||
*/
|
||||
public async manageLanguages(): Promise<LanguagesDialog> {
|
||||
|
||||
let dialog = new LanguagesDialog(this._apiWrapper, this._root);
|
||||
|
||||
// Load current connection
|
||||
//
|
||||
await this._service.load();
|
||||
dialog.connection = this._service.connection;
|
||||
dialog.connectionUrl = this._service.connectionUrl;
|
||||
|
||||
// Handle dialog events and connect to model
|
||||
//
|
||||
dialog.onEdit(model => {
|
||||
this.editLanguage(dialog, model);
|
||||
});
|
||||
dialog.onDelete(async deleteModel => {
|
||||
try {
|
||||
await this.executeAction(dialog, this.deleteLanguage, this._service, deleteModel);
|
||||
dialog.onUpdatedLanguage(deleteModel);
|
||||
} catch (err) {
|
||||
dialog.onActionFailed(err);
|
||||
}
|
||||
});
|
||||
|
||||
dialog.onUpdate(async updateModel => {
|
||||
try {
|
||||
await this.executeAction(dialog, this.updateLanguage, this._service, updateModel);
|
||||
dialog.onUpdatedLanguage(updateModel);
|
||||
} catch (err) {
|
||||
dialog.onActionFailed(err);
|
||||
}
|
||||
});
|
||||
|
||||
dialog.onList(async () => {
|
||||
try {
|
||||
let result = await this.listLanguages(this._service);
|
||||
dialog.onListLanguageLoaded(result);
|
||||
} catch (err) {
|
||||
dialog.onActionFailed(err);
|
||||
}
|
||||
});
|
||||
this.onSelectFile(dialog);
|
||||
|
||||
// Open dialog
|
||||
//
|
||||
dialog.showDialog();
|
||||
return dialog;
|
||||
}
|
||||
|
||||
public async executeAction<T>(dialog: LanguageViewBase, func: (...args: any[]) => Promise<T>, ...args: any[]): Promise<T> {
|
||||
let result = await func(...args);
|
||||
await dialog.reset();
|
||||
return result;
|
||||
}
|
||||
|
||||
public editLanguage(parent: LanguageViewBase, languageUpdateModel: LanguageUpdateModel): void {
|
||||
let editDialog = new LanguageEditDialog(this._apiWrapper, parent, languageUpdateModel);
|
||||
editDialog.showDialog();
|
||||
}
|
||||
|
||||
private onSelectFile(dialog: LanguageViewBase): void {
|
||||
dialog.fileBrowser(async (args) => {
|
||||
let filePath = '';
|
||||
if (args.target === constants.localhost) {
|
||||
filePath = await this.getLocalFilePath();
|
||||
|
||||
} else {
|
||||
filePath = await this.getServerFilePath(args.target);
|
||||
}
|
||||
dialog.onFilePathSelected({ filePath: filePath, target: args.target });
|
||||
});
|
||||
}
|
||||
|
||||
public getServerFilePath(connectionUrl: string): Promise<string> {
|
||||
return new Promise<string>((resolve) => {
|
||||
let dialog = new FileBrowserDialog(this._apiWrapper, connectionUrl);
|
||||
dialog.onPathSelected((selectedPath) => {
|
||||
resolve(selectedPath);
|
||||
});
|
||||
|
||||
dialog.showDialog();
|
||||
});
|
||||
}
|
||||
|
||||
public async getLocalFilePath(): Promise<string> {
|
||||
let result = await this._apiWrapper.showOpenDialog({
|
||||
canSelectFiles: true,
|
||||
canSelectFolders: false,
|
||||
canSelectMany: false
|
||||
});
|
||||
return result && result.length > 0 ? result[0].fsPath : '';
|
||||
}
|
||||
|
||||
public async deleteLanguage(model: LanguageService, deleteModel: LanguageUpdateModel): Promise<void> {
|
||||
await model.deleteLanguage(deleteModel.language.name);
|
||||
}
|
||||
|
||||
public async listLanguages(model: LanguageService): Promise<mssql.ExternalLanguage[]> {
|
||||
return await model.getLanguageList();
|
||||
}
|
||||
|
||||
public async updateLanguage(model: LanguageService, updateModel: LanguageUpdateModel): Promise<void> {
|
||||
if (!updateModel.language) {
|
||||
return;
|
||||
}
|
||||
let contents: mssql.ExternalLanguageContent[] = [];
|
||||
if (updateModel.language.contents && updateModel.language.contents.length >= 0) {
|
||||
contents = updateModel.language.contents.filter(x => x.platform !== updateModel.content.platform);
|
||||
}
|
||||
contents.push(updateModel.content);
|
||||
|
||||
updateModel.language.contents = contents;
|
||||
await model.updateLanguage(updateModel.language);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as constants from '../../common/constants';
|
||||
import { AddEditLanguageTab } from './addEditLanguageTab';
|
||||
import { LanguageViewBase, LanguageUpdateModel } from './languageViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export class LanguageEditDialog extends LanguageViewBase {
|
||||
|
||||
public addNewLanguageTab: AddEditLanguageTab | undefined;
|
||||
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
parent: LanguageViewBase,
|
||||
private _languageUpdateModel: LanguageUpdateModel) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens a dialog to edit a language or a content of a language
|
||||
*/
|
||||
public showDialog(): void {
|
||||
this._dialog = this._apiWrapper.createModelViewDialog(constants.extLangDialogTitle);
|
||||
|
||||
this.addNewLanguageTab = new AddEditLanguageTab(this._apiWrapper, this, this._languageUpdateModel);
|
||||
|
||||
this._dialog.cancelButton.label = constants.extLangCancelButtonText;
|
||||
this._dialog.okButton.label = constants.extLangSaveButtonText;
|
||||
|
||||
this.dialog?.registerCloseValidator(async (): Promise<boolean> => {
|
||||
return await this.onSave();
|
||||
});
|
||||
|
||||
this._dialog.content = [this.addNewLanguageTab.tab];
|
||||
this._apiWrapper.openDialog(this._dialog);
|
||||
}
|
||||
|
||||
public async onSave(): Promise<boolean> {
|
||||
if (this.addNewLanguageTab) {
|
||||
try {
|
||||
await this.updateLanguage(this.addNewLanguageTab.updatedData);
|
||||
return true;
|
||||
} catch (err) {
|
||||
this.showErrorMessage(constants.extLangUpdateFailedError, err);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
/**
|
||||
* Resets the tabs for given provider Id
|
||||
*/
|
||||
public async reset(): Promise<void> {
|
||||
await this.addNewLanguageTab?.reset();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as mssql from '../../../../mssql';
|
||||
import * as path from 'path';
|
||||
|
||||
export interface LanguageUpdateModel {
|
||||
language: mssql.ExternalLanguage,
|
||||
content: mssql.ExternalLanguageContent,
|
||||
newLang: boolean
|
||||
}
|
||||
|
||||
export interface FileBrowseEventArgs {
|
||||
filePath: string,
|
||||
target: string
|
||||
}
|
||||
|
||||
export abstract class LanguageViewBase {
|
||||
protected _dialog: azdata.window.Dialog | undefined;
|
||||
public connection: azdata.connection.ConnectionProfile | undefined;
|
||||
public connectionUrl: string = '';
|
||||
|
||||
// Events
|
||||
//
|
||||
protected _onEdit: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
|
||||
public readonly onEdit: vscode.Event<LanguageUpdateModel> = this._onEdit.event;
|
||||
|
||||
protected _onUpdate: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
|
||||
public readonly onUpdate: vscode.Event<LanguageUpdateModel> = this._onUpdate.event;
|
||||
|
||||
protected _onDelete: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
|
||||
public readonly onDelete: vscode.Event<LanguageUpdateModel> = this._onDelete.event;
|
||||
|
||||
protected _fileBrowser: vscode.EventEmitter<FileBrowseEventArgs> = new vscode.EventEmitter<FileBrowseEventArgs>();
|
||||
public readonly fileBrowser: vscode.Event<FileBrowseEventArgs> = this._fileBrowser.event;
|
||||
|
||||
protected _filePathSelected: vscode.EventEmitter<FileBrowseEventArgs> = new vscode.EventEmitter<FileBrowseEventArgs>();
|
||||
public readonly filePathSelected: vscode.Event<FileBrowseEventArgs> = this._filePathSelected.event;
|
||||
|
||||
protected _onUpdated: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
|
||||
public readonly onUpdated: vscode.Event<LanguageUpdateModel> = this._onUpdated.event;
|
||||
|
||||
protected _onList: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onList: vscode.Event<void> = this._onList.event;
|
||||
|
||||
protected _onListLoaded: vscode.EventEmitter<mssql.ExternalLanguage[]> = new vscode.EventEmitter<mssql.ExternalLanguage[]>();
|
||||
public readonly onListLoaded: vscode.Event<mssql.ExternalLanguage[]> = this._onListLoaded.event;
|
||||
|
||||
protected _onFailed: vscode.EventEmitter<any> = new vscode.EventEmitter<any>();
|
||||
public readonly onFailed: vscode.Event<any> = this._onFailed.event;
|
||||
|
||||
public componentMaxLength = 350;
|
||||
public browseButtonMaxLength = 20;
|
||||
public spaceBetweenComponentsLength = 10;
|
||||
|
||||
constructor(protected _apiWrapper: ApiWrapper, protected _root?: string, protected _parent?: LanguageViewBase,) {
|
||||
if (this._parent) {
|
||||
if (!this._root) {
|
||||
this._root = this._parent.root;
|
||||
}
|
||||
this.connection = this._parent.connection;
|
||||
this.connectionUrl = this._parent.connectionUrl;
|
||||
}
|
||||
this.registerEvents();
|
||||
}
|
||||
|
||||
private registerEvents() {
|
||||
if (this._parent) {
|
||||
this._dialog = this._parent.dialog;
|
||||
this.fileBrowser(url => {
|
||||
this._parent?.onOpenFileBrowser(url);
|
||||
});
|
||||
this.onUpdate(model => {
|
||||
this._parent?.onUpdateLanguage(model);
|
||||
});
|
||||
this.onEdit(model => {
|
||||
this._parent?.onEditLanguage(model);
|
||||
});
|
||||
this.onDelete(model => {
|
||||
this._parent?.onDeleteLanguage(model);
|
||||
});
|
||||
this.onList(() => {
|
||||
this._parent?.onListLanguages();
|
||||
});
|
||||
this._parent.filePathSelected(x => {
|
||||
this.onFilePathSelected(x);
|
||||
});
|
||||
this._parent.onUpdated(x => {
|
||||
this.onUpdatedLanguage(x);
|
||||
});
|
||||
this._parent.onFailed(x => {
|
||||
this.onActionFailed(x);
|
||||
});
|
||||
this._parent.onListLoaded(x => {
|
||||
this.onListLanguageLoaded(x);
|
||||
});
|
||||
}
|
||||
}
|
||||
public async getLocationTitle(): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : constants.extLangLocal}`;
|
||||
}
|
||||
return constants.noConnectionError;
|
||||
}
|
||||
|
||||
public getServerTitle(): string {
|
||||
if (this.connection) {
|
||||
return this.connection.serverName;
|
||||
}
|
||||
return constants.noConnectionError;
|
||||
}
|
||||
|
||||
private async getCurrentConnectionUrl(): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
return await this._apiWrapper.getUriForConnection(connection.connectionId);
|
||||
}
|
||||
return '';
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
public async loadConnection(): Promise<void> {
|
||||
this.connection = await this.getCurrentConnection();
|
||||
this.connectionUrl = await this.getCurrentConnectionUrl();
|
||||
}
|
||||
|
||||
public updateLanguage(updateModel: LanguageUpdateModel): Promise<void> {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
this.onUpdateLanguage(updateModel);
|
||||
this.onUpdated(() => {
|
||||
resolve();
|
||||
});
|
||||
this.onFailed(err => {
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
public deleteLanguage(model: LanguageUpdateModel): Promise<void> {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
this.onDeleteLanguage(model);
|
||||
this.onUpdated(() => {
|
||||
resolve();
|
||||
});
|
||||
this.onFailed(err => {
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
public listLanguages(): Promise<mssql.ExternalLanguage[]> {
|
||||
return new Promise<mssql.ExternalLanguage[]>((resolve, reject) => {
|
||||
this.onListLanguages();
|
||||
this.onListLoaded(list => {
|
||||
resolve(list);
|
||||
});
|
||||
this.onFailed(err => {
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Dialog model instance
|
||||
*/
|
||||
public get dialog(): azdata.window.Dialog | undefined {
|
||||
return this._dialog;
|
||||
}
|
||||
|
||||
public set dialog(value: azdata.window.Dialog | undefined) {
|
||||
this._dialog = value;
|
||||
}
|
||||
|
||||
public showInfoMessage(message: string): void {
|
||||
this.showMessage(message, azdata.window.MessageLevel.Information);
|
||||
}
|
||||
|
||||
public showErrorMessage(message: string, error?: any): void {
|
||||
this.showMessage(`${message} ${constants.getErrorMessage(error)}`, azdata.window.MessageLevel.Error);
|
||||
}
|
||||
|
||||
public onUpdateLanguage(model: LanguageUpdateModel): void {
|
||||
this._onUpdate.fire(model);
|
||||
}
|
||||
|
||||
public onUpdatedLanguage(model: LanguageUpdateModel): void {
|
||||
this._onUpdated.fire(model);
|
||||
}
|
||||
|
||||
public onActionFailed(error: any): void {
|
||||
this._onFailed.fire(error);
|
||||
}
|
||||
|
||||
public onListLanguageLoaded(list: mssql.ExternalLanguage[]): void {
|
||||
this._onListLoaded.fire(list);
|
||||
}
|
||||
|
||||
public onEditLanguage(model: LanguageUpdateModel): void {
|
||||
this._onEdit.fire(model);
|
||||
}
|
||||
|
||||
public onDeleteLanguage(model: LanguageUpdateModel): void {
|
||||
this._onDelete.fire(model);
|
||||
}
|
||||
|
||||
public onListLanguages(): void {
|
||||
this._onList.fire();
|
||||
}
|
||||
|
||||
public onOpenFileBrowser(fileBrowseArgs: FileBrowseEventArgs): void {
|
||||
this._fileBrowser.fire(fileBrowseArgs);
|
||||
}
|
||||
|
||||
public onFilePathSelected(fileBrowseArgs: FileBrowseEventArgs): void {
|
||||
this._filePathSelected.fire(fileBrowseArgs);
|
||||
}
|
||||
|
||||
private showMessage(message: string, level: azdata.window.MessageLevel): void {
|
||||
if (this._dialog) {
|
||||
this._dialog.message = {
|
||||
text: message,
|
||||
level: level
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
public get root(): string {
|
||||
return this._root || '';
|
||||
}
|
||||
|
||||
public asAbsolutePath(filePath: string): string {
|
||||
return path.join(this._root || '', filePath);
|
||||
}
|
||||
|
||||
public abstract reset(): Promise<void>;
|
||||
|
||||
public createNewContent(): mssql.ExternalLanguageContent {
|
||||
return {
|
||||
extensionFileName: '',
|
||||
isLocalFile: true,
|
||||
pathToExtension: '',
|
||||
};
|
||||
}
|
||||
|
||||
public createNewLanguage(): mssql.ExternalLanguage {
|
||||
return {
|
||||
name: '',
|
||||
contents: []
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { CurrentLanguagesTab } from './currentLanguagesTab';
|
||||
import { AddEditLanguageTab } from './addEditLanguageTab';
|
||||
import { LanguageViewBase } from './languageViewBase';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export class LanguagesDialog extends LanguageViewBase {
|
||||
|
||||
public currentLanguagesTab: CurrentLanguagesTab | undefined;
|
||||
public addNewLanguageTab: AddEditLanguageTab | undefined;
|
||||
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
root: string) {
|
||||
super(apiWrapper, root);
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens a dialog to manage packages used by notebooks.
|
||||
*/
|
||||
public showDialog(): void {
|
||||
this.dialog = this._apiWrapper.createModelViewDialog(constants.extLangDialogTitle);
|
||||
|
||||
this.currentLanguagesTab = new CurrentLanguagesTab(this._apiWrapper, this);
|
||||
|
||||
let languageUpdateModel = {
|
||||
language: this.createNewLanguage(),
|
||||
content: this.createNewContent(),
|
||||
newLang: true
|
||||
};
|
||||
this.addNewLanguageTab = new AddEditLanguageTab(this._apiWrapper, this, languageUpdateModel);
|
||||
|
||||
this.dialog.okButton.hidden = true;
|
||||
this.dialog.cancelButton.label = constants.extLangDoneButtonText;
|
||||
this.dialog.content = [this.currentLanguagesTab.tab, this.addNewLanguageTab.tab];
|
||||
|
||||
this.dialog.registerCloseValidator(() => {
|
||||
return false; // Blocks Enter key from closing dialog.
|
||||
});
|
||||
|
||||
this._apiWrapper.openDialog(this.dialog);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the tabs for given provider Id
|
||||
*/
|
||||
public async reset(): Promise<void> {
|
||||
await this.currentLanguagesTab?.reset();
|
||||
await this.addNewLanguageTab?.reset();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as constants from '../../common/constants';
|
||||
import * as mssql from '../../../../mssql';
|
||||
import { LanguageViewBase } from './languageViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
|
||||
export class LanguagesTable extends LanguageViewBase {
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: LanguageViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._table = _modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: [
|
||||
{ // Name
|
||||
displayName: constants.extLangLanguageName,
|
||||
ariaLabel: constants.extLangLanguageName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 100,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Platform
|
||||
displayName: constants.extLangLanguagePlatform,
|
||||
ariaLabel: constants.extLangLanguagePlatform,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Created Date
|
||||
displayName: constants.extLangLanguageCreatedDate,
|
||||
ariaLabel: constants.extLangLanguageCreatedDate,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
],
|
||||
data: [],
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
}
|
||||
|
||||
public get table(): azdata.DeclarativeTableComponent {
|
||||
return this._table;
|
||||
}
|
||||
|
||||
public async loadData(): Promise<void> {
|
||||
let languages: mssql.ExternalLanguage[] | undefined;
|
||||
|
||||
languages = await this.listLanguages();
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (languages) {
|
||||
|
||||
languages.forEach(language => {
|
||||
if (!language.contents || language.contents.length === 0) {
|
||||
language.contents.push(this.createNewContent());
|
||||
}
|
||||
|
||||
tableData = tableData.concat(language.contents.map(content => this.createTableRow(language, content)));
|
||||
});
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
|
||||
private createTableRow(language: mssql.ExternalLanguage, content: mssql.ExternalLanguageContent): any[] {
|
||||
if (this._modelBuilder) {
|
||||
let dropLanguageButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
title: constants.deleteTitle,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/delete.svg')
|
||||
},
|
||||
width: 15,
|
||||
height: 15
|
||||
}).component();
|
||||
dropLanguageButton.onDidClick(async () => {
|
||||
await this.deleteLanguage({
|
||||
language: language,
|
||||
content: content,
|
||||
newLang: false
|
||||
});
|
||||
});
|
||||
|
||||
let editLanguageButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
title: constants.deleteTitle,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/edit.svg')
|
||||
},
|
||||
width: 15,
|
||||
height: 15
|
||||
}).component();
|
||||
editLanguageButton.onDidClick(() => {
|
||||
this.onEditLanguage({
|
||||
language: language,
|
||||
content: content,
|
||||
newLang: false
|
||||
});
|
||||
});
|
||||
return [language.name, content.platform, language.createdDate, dropLanguageButton, editLanguageButton];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
}
|
||||
43
extensions/machine-learning/src/views/interfaces.ts
Normal file
43
extensions/machine-learning/src/views/interfaces.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
import * as azdata from 'azdata';
|
||||
import { azureResource } from '../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { WorkspaceModel } from '../modelManagement/interfaces';
|
||||
|
||||
export interface IDataComponent<T> {
|
||||
data: T | undefined;
|
||||
}
|
||||
|
||||
export interface IPageView {
|
||||
registerComponent: (modelBuilder: azdata.ModelBuilder) => azdata.Component;
|
||||
component: azdata.Component | undefined;
|
||||
onEnter?: () => Promise<void>;
|
||||
onLeave?: () => Promise<void>;
|
||||
validate?: () => Promise<boolean>;
|
||||
refresh: () => Promise<void>;
|
||||
disposePage?: () => Promise<void>;
|
||||
viewPanel: azdata.window.ModelViewPanel | undefined;
|
||||
title: string;
|
||||
}
|
||||
|
||||
export interface AzureWorkspaceResource {
|
||||
account?: azdata.Account,
|
||||
subscription?: azureResource.AzureResourceSubscription,
|
||||
group?: azureResource.AzureResource,
|
||||
workspace?: Workspace
|
||||
}
|
||||
|
||||
export interface AzureModelResource extends AzureWorkspaceResource {
|
||||
model?: WorkspaceModel;
|
||||
}
|
||||
|
||||
export interface IComponentSettings {
|
||||
multiSelect?: boolean;
|
||||
editable?: boolean;
|
||||
selectable?: boolean;
|
||||
}
|
||||
|
||||
|
||||
55
extensions/machine-learning/src/views/mainViewBase.ts
Normal file
55
extensions/machine-learning/src/views/mainViewBase.ts
Normal file
@@ -0,0 +1,55 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
import * as azdata from 'azdata';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { IPageView } from './interfaces';
|
||||
|
||||
/**
|
||||
* Base class for dialog and wizard
|
||||
*/
|
||||
export class MainViewBase {
|
||||
|
||||
protected _pages: IPageView[] = [];
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
constructor(protected _apiWrapper: ApiWrapper) {
|
||||
}
|
||||
|
||||
protected registerContent(viewPanel: azdata.window.DialogTab | azdata.window.WizardPage, componentView: IPageView) {
|
||||
viewPanel.registerContent(async view => {
|
||||
if (componentView) {
|
||||
let component = componentView.registerComponent(view.modelBuilder);
|
||||
await view.initializeModel(component);
|
||||
await componentView.refresh();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
protected addPage(page: IPageView, index?: number): void {
|
||||
if (index) {
|
||||
this._pages[index] = page;
|
||||
} else {
|
||||
this._pages.push(page);
|
||||
}
|
||||
}
|
||||
|
||||
public async disposePages(): Promise<void> {
|
||||
if (this._pages) {
|
||||
await Promise.all(this._pages.map(async (p) => {
|
||||
if (p.disposePage) {
|
||||
await p.disposePage();
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._pages) {
|
||||
await Promise.all(this._pages.map(async (p) => await p.refresh()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
|
||||
import { AzureModelsTable } from './azureModelsTable';
|
||||
import { IDataComponent, AzureModelResource } from '../interfaces';
|
||||
import { ModelArtifact } from './prediction/modelArtifact';
|
||||
import { AzureSignInComponent } from './azureSignInComponent';
|
||||
|
||||
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource[]> {
|
||||
|
||||
public azureModelsTable: AzureModelsTable | undefined;
|
||||
public azureFilterComponent: AzureResourceFilterComponent | undefined;
|
||||
public azureSignInComponent: AzureSignInComponent | undefined;
|
||||
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _downloadedFile: ModelArtifact | undefined;
|
||||
|
||||
/**
|
||||
* Component to render a view to pick an azure model
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this);
|
||||
this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect);
|
||||
this.azureSignInComponent = new AzureSignInComponent(this._apiWrapper, modelBuilder, this);
|
||||
this._loader = modelBuilder.loadingComponent()
|
||||
.withItem(this.azureModelsTable.component)
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
this.azureModelsTable.onModelSelectionChanged(async () => {
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
this._downloadedFile = undefined;
|
||||
});
|
||||
|
||||
this.azureFilterComponent.onWorkspacesSelectedChanged(async () => {
|
||||
await this.onLoading();
|
||||
await this.azureModelsTable?.loadData(this.azureFilterComponent?.data);
|
||||
await this.onLoaded();
|
||||
});
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
component: this.azureFilterComponent.component
|
||||
}, {
|
||||
title: '',
|
||||
component: this._loader
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
this.removeComponents(formBuilder);
|
||||
if (this.azureFilterComponent?.data?.account) {
|
||||
this.addAzureComponents(formBuilder);
|
||||
} else {
|
||||
this.addAzureSignInComponents(formBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
this.removeAzureComponents(formBuilder);
|
||||
this.removeAzureSignInComponents(formBuilder);
|
||||
}
|
||||
|
||||
private addAzureComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.azureFilterComponent && this._loader) {
|
||||
this.azureFilterComponent.addComponents(formBuilder);
|
||||
|
||||
formBuilder.addFormItems([{
|
||||
title: '',
|
||||
component: this._loader
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
private removeAzureComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.azureFilterComponent && this._loader) {
|
||||
this.azureFilterComponent.removeComponents(formBuilder);
|
||||
formBuilder.removeFormItem({
|
||||
title: '',
|
||||
component: this._loader
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private addAzureSignInComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.azureSignInComponent) {
|
||||
this.azureSignInComponent.addComponents(formBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
private removeAzureSignInComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.azureSignInComponent) {
|
||||
this.azureSignInComponent.removeComponents(formBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
private async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
}
|
||||
}
|
||||
|
||||
private async onLoaded(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: false });
|
||||
}
|
||||
}
|
||||
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the data in the components
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
await this.azureFilterComponent?.loadData();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): AzureModelResource[] | undefined {
|
||||
return this.azureModelsTable?.data ? this.azureModelsTable?.data.map(x => Object.assign({}, this.azureFilterComponent?.data, {
|
||||
model: x
|
||||
})) : undefined;
|
||||
}
|
||||
|
||||
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
|
||||
const data = this.data;
|
||||
if (!this._downloadedFile && data && data.length > 0) {
|
||||
this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(data[0]));
|
||||
}
|
||||
return this._downloadedFile;
|
||||
}
|
||||
|
||||
/**
|
||||
* disposes the view
|
||||
*/
|
||||
public async disposeComponent(): Promise<void> {
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
}
|
||||
184
extensions/machine-learning/src/views/models/azureModelsTable.ts
Normal file
184
extensions/machine-learning/src/views/models/azureModelsTable.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as constants from '../../common/constants';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { WorkspaceModel } from '../../modelManagement/interfaces';
|
||||
import { IDataComponent, AzureWorkspaceResource } from '../interfaces';
|
||||
|
||||
/**
|
||||
* View to render azure models in a table
|
||||
*/
|
||||
export class AzureModelsTable extends ModelViewBase implements IDataComponent<WorkspaceModel[]> {
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent;
|
||||
private _selectedModel: WorkspaceModel[] = [];
|
||||
private _models: WorkspaceModel[] | undefined;
|
||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
||||
|
||||
/**
|
||||
* Creates a view to render azure models in a table
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._table = this.registerComponent(this._modelBuilder);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: [
|
||||
{ // Name
|
||||
displayName: constants.modelName,
|
||||
ariaLabel: constants.modelName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Created
|
||||
displayName: constants.modelCreated,
|
||||
ariaLabel: constants.modelCreated,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 100,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Version
|
||||
displayName: constants.modelVersion,
|
||||
ariaLabel: constants.modelVersion,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 100,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
],
|
||||
data: [],
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
return this._table;
|
||||
}
|
||||
|
||||
public get component(): azdata.DeclarativeTableComponent {
|
||||
return this._table;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load data in the component
|
||||
* @param workspaceResource Azure workspace
|
||||
*/
|
||||
public async loadData(workspaceResource?: AzureWorkspaceResource | undefined): Promise<void> {
|
||||
|
||||
if (this._table && workspaceResource) {
|
||||
this._models = await this.listAzureModels(workspaceResource);
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (this._models) {
|
||||
tableData = tableData.concat(this._models.map(model => this.createTableRow(model)));
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
this._onModelSelectionChanged.fire();
|
||||
}
|
||||
|
||||
private createTableRow(model: WorkspaceModel): any[] {
|
||||
if (this._modelBuilder) {
|
||||
let selectModelButton: azdata.Component;
|
||||
let onSelectItem = (checked: boolean) => {
|
||||
const foundItem = this._selectedModel.find(x => x === model);
|
||||
if (checked && !foundItem) {
|
||||
this._selectedModel.push(model);
|
||||
} else if (foundItem) {
|
||||
this._selectedModel = this._selectedModel.filter(x => x !== model);
|
||||
}
|
||||
this._onModelSelectionChanged.fire();
|
||||
};
|
||||
if (this._multiSelect) {
|
||||
const checkbox = this._modelBuilder.checkBox().withProperties({
|
||||
name: 'amlModel',
|
||||
value: model.id,
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: false
|
||||
}).component();
|
||||
checkbox.onChanged(() => {
|
||||
onSelectItem(checkbox.checked || false);
|
||||
});
|
||||
selectModelButton = checkbox;
|
||||
} else {
|
||||
const radioButton = this._modelBuilder.radioButton().withProperties({
|
||||
name: 'amlModel',
|
||||
value: model.id,
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: false
|
||||
}).component();
|
||||
radioButton.onDidClick(() => {
|
||||
onSelectItem(radioButton.checked || false);
|
||||
});
|
||||
selectModelButton = radioButton;
|
||||
}
|
||||
|
||||
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): WorkspaceModel[] | undefined {
|
||||
if (this._models && this._selectedModel) {
|
||||
return this._selectedModel;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,207 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import * as constants from '../../common/constants';
|
||||
import { AzureWorkspaceResource, IDataComponent } from '../interfaces';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
const componentWidth = 300;
|
||||
export class AzureResourceFilterComponent extends ModelViewBase implements IDataComponent<AzureWorkspaceResource> {
|
||||
|
||||
private _form: azdata.FormContainer;
|
||||
private _accounts: azdata.DropDownComponent;
|
||||
private _subscriptions: azdata.DropDownComponent;
|
||||
private _groups: azdata.DropDownComponent;
|
||||
private _workspaces: azdata.DropDownComponent;
|
||||
private _azureAccounts: azdata.Account[] = [];
|
||||
private _azureSubscriptions: azureResource.AzureResourceSubscription[] = [];
|
||||
private _azureGroups: azureResource.AzureResource[] = [];
|
||||
private _azureWorkspaces: Workspace[] = [];
|
||||
private _onWorkspacesSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onWorkspacesSelectedChanged: vscode.Event<void> = this._onWorkspacesSelectedChanged.event;
|
||||
|
||||
/**
|
||||
* Creates a new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._accounts = this._modelBuilder.dropDown().withProperties({
|
||||
width: componentWidth
|
||||
}).component();
|
||||
this._subscriptions = this._modelBuilder.dropDown().withProperties({
|
||||
width: componentWidth
|
||||
}).component();
|
||||
this._groups = this._modelBuilder.dropDown().withProperties({
|
||||
width: componentWidth
|
||||
}).component();
|
||||
this._workspaces = this._modelBuilder.dropDown().withProperties({
|
||||
width: componentWidth
|
||||
}).component();
|
||||
|
||||
this._accounts.onValueChanged(async () => {
|
||||
await this.onAccountSelected();
|
||||
});
|
||||
|
||||
this._subscriptions.onValueChanged(async () => {
|
||||
await this.onSubscriptionSelected();
|
||||
});
|
||||
this._groups.onValueChanged(async () => {
|
||||
await this.onGroupSelected();
|
||||
});
|
||||
this._workspaces.onValueChanged(async () => {
|
||||
await this.onWorkspaceSelectedChanged();
|
||||
});
|
||||
|
||||
this._form = this._modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._accounts
|
||||
}, {
|
||||
title: constants.azureSubscription,
|
||||
component: this._subscriptions
|
||||
}, {
|
||||
title: constants.azureGroup,
|
||||
component: this._groups
|
||||
}, {
|
||||
title: constants.azureModelWorkspace,
|
||||
component: this._workspaces
|
||||
}]).component();
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._accounts
|
||||
}, {
|
||||
title: constants.azureSubscription,
|
||||
component: this._subscriptions
|
||||
}, {
|
||||
title: constants.azureGroup,
|
||||
component: this._groups
|
||||
}, {
|
||||
title: constants.azureModelWorkspace,
|
||||
component: this._workspaces
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureAccount,
|
||||
component: this._accounts
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureSubscription,
|
||||
component: this._subscriptions
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureGroup,
|
||||
component: this._groups
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureModelWorkspace,
|
||||
component: this._workspaces
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the created component
|
||||
*/
|
||||
public get component(): azdata.Component {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): AzureWorkspaceResource | undefined {
|
||||
return {
|
||||
account: this.account,
|
||||
subscription: this.subscription,
|
||||
group: this.group,
|
||||
workspace: this.workspace
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* loads data in the components
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
this._azureAccounts = await this.listAzureAccounts();
|
||||
if (this._azureAccounts && this._azureAccounts.length > 0) {
|
||||
let values = this._azureAccounts.map(a => { return { displayName: a.displayInfo.displayName, name: a.key.accountId }; });
|
||||
this._accounts.values = values;
|
||||
this._accounts.value = values[0];
|
||||
}
|
||||
await this.onAccountSelected();
|
||||
}
|
||||
|
||||
/**
|
||||
* refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
|
||||
private async onAccountSelected(): Promise<void> {
|
||||
this._azureSubscriptions = await this.listAzureSubscriptions(this.account);
|
||||
if (this._azureSubscriptions && this._azureSubscriptions.length > 0) {
|
||||
let values = this._azureSubscriptions.map(s => { return { displayName: s.name, name: s.id }; });
|
||||
this._subscriptions.values = values;
|
||||
this._subscriptions.value = values[0];
|
||||
}
|
||||
await this.onSubscriptionSelected();
|
||||
}
|
||||
|
||||
private async onSubscriptionSelected(): Promise<void> {
|
||||
this._azureGroups = await this.listAzureGroups(this.account, this.subscription);
|
||||
if (this._azureGroups && this._azureGroups.length > 0) {
|
||||
let values = this._azureGroups.map(s => { return { displayName: s.name, name: s.id }; });
|
||||
this._groups.values = values;
|
||||
this._groups.value = values[0];
|
||||
}
|
||||
await this.onGroupSelected();
|
||||
}
|
||||
|
||||
private async onGroupSelected(): Promise<void> {
|
||||
this._azureWorkspaces = await this.listWorkspaces(this.account, this.subscription, this.group);
|
||||
if (this._azureWorkspaces && this._azureWorkspaces.length > 0) {
|
||||
let values = this._azureWorkspaces.map(s => { return { displayName: s.name || '', name: s.id || '' }; });
|
||||
this._workspaces.values = values;
|
||||
this._workspaces.value = values[0];
|
||||
}
|
||||
this.onWorkspaceSelectedChanged();
|
||||
}
|
||||
|
||||
private onWorkspaceSelectedChanged(): void {
|
||||
this._onWorkspacesSelectedChanged.fire();
|
||||
}
|
||||
|
||||
private get workspace(): Workspace | undefined {
|
||||
return this._azureWorkspaces && this._workspaces.value ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
|
||||
}
|
||||
|
||||
private get account(): azdata.Account | undefined {
|
||||
return this._azureAccounts && this._accounts.value ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
|
||||
}
|
||||
|
||||
private get group(): azureResource.AzureResource | undefined {
|
||||
return this._azureGroups && this._groups.value ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
|
||||
}
|
||||
|
||||
private get subscription(): azureResource.AzureResourceSubscription | undefined {
|
||||
return this._azureSubscriptions && this._subscriptions.value ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, SignInToAzureEventName } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
const componentWidth = 300;
|
||||
export class AzureSignInComponent extends ModelViewBase {
|
||||
|
||||
private _form: azdata.FormContainer;
|
||||
private _signInButton: azdata.ButtonComponent;
|
||||
|
||||
/**
|
||||
* Creates a new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._signInButton = this._modelBuilder.button().withProperties({
|
||||
width: componentWidth,
|
||||
label: constants.azureSignIn,
|
||||
}).component();
|
||||
this._signInButton.onDidClick(() => {
|
||||
this.sendRequest(SignInToAzureEventName);
|
||||
});
|
||||
|
||||
this._form = this._modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._signInButton
|
||||
}]).component();
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._signInButton) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.azureAccount,
|
||||
component: this._signInButton
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._signInButton) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.azureAccount,
|
||||
component: this._signInButton
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the created component
|
||||
*/
|
||||
public get component(): azdata.Component {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
|
||||
import { ModelViewBase } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
|
||||
/**
|
||||
* View to pick local models file
|
||||
*/
|
||||
export class LocalModelsComponent extends ModelViewBase implements IDataComponent<string[]> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _flex: azdata.FlexContainer | undefined;
|
||||
private _localPath: azdata.InputBoxComponent | undefined;
|
||||
private _localBrowse: azdata.ButtonComponent | undefined;
|
||||
|
||||
/**
|
||||
* Creates new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._localPath = modelBuilder.inputBox().withProperties({
|
||||
value: '',
|
||||
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength
|
||||
}).component();
|
||||
this._localBrowse = modelBuilder.button().withProperties({
|
||||
label: constants.browseModels,
|
||||
width: this.browseButtonMaxLength,
|
||||
CSSStyles: {
|
||||
'text-align': 'end'
|
||||
}
|
||||
}).component();
|
||||
this._localBrowse.onDidClick(async () => {
|
||||
|
||||
let options: vscode.OpenDialogOptions = {
|
||||
canSelectFiles: true,
|
||||
canSelectFolders: false,
|
||||
canSelectMany: this._multiSelect,
|
||||
filters: { 'ONNX File': ['onnx'] }
|
||||
};
|
||||
|
||||
const filePaths = await this.getLocalPaths(options);
|
||||
if (this._localPath && filePaths && filePaths.length > 0) {
|
||||
this._localPath.value = this._multiSelect ? filePaths.join(';') : filePaths[0];
|
||||
} else if (this._localPath) {
|
||||
this._localPath.value = '';
|
||||
}
|
||||
});
|
||||
|
||||
this._flex = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between',
|
||||
width: this.componentMaxLength
|
||||
}).withItems([
|
||||
this._localPath, this._localBrowse]
|
||||
).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
component: this._flex
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
formBuilder.addFormItem({
|
||||
title: '',
|
||||
component: this._flex
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flex) {
|
||||
formBuilder.removeFormItem({
|
||||
title: '',
|
||||
component: this._flex
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): string[] {
|
||||
if (this._localPath?.value) {
|
||||
return this._localPath?.value.split(';');
|
||||
} else {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.localModelsTitle;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { CurrentModelsTable } from './currentModelsTable';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { IPageView, IComponentSettings } from '../../interfaces';
|
||||
import { TableSelectionComponent } from '../tableSelectionComponent';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to render current registered models
|
||||
*/
|
||||
export class CurrentModelsComponent extends ModelViewBase implements IPageView {
|
||||
private _tableComponent: azdata.Component | undefined;
|
||||
private _dataTable: CurrentModelsTable | undefined;
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
private _tableSelectionComponent: TableSelectionComponent | undefined;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param apiWrapper Creates new view
|
||||
* @param parent page parent
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
|
||||
this._tableSelectionComponent.registerComponent(modelBuilder);
|
||||
this._tableSelectionComponent.onSelectedChanged(async () => {
|
||||
await this.onTableSelected();
|
||||
});
|
||||
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings);
|
||||
this._dataTable.registerComponent(modelBuilder);
|
||||
this._tableComponent = this._dataTable.component;
|
||||
|
||||
let formModelBuilder = modelBuilder.formContainer();
|
||||
this._tableSelectionComponent.addComponents(formModelBuilder);
|
||||
|
||||
if (this._tableComponent) {
|
||||
formModelBuilder.addFormItem({
|
||||
component: this._tableComponent,
|
||||
title: ''
|
||||
});
|
||||
}
|
||||
|
||||
this._loader = modelBuilder.loadingComponent()
|
||||
.withItem(formModelBuilder.component())
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._tableSelectionComponent && this._dataTable) {
|
||||
this._tableSelectionComponent.addComponents(formBuilder);
|
||||
this._dataTable.addComponents(formBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._tableSelectionComponent && this._dataTable) {
|
||||
this._tableSelectionComponent.removeComponents(formBuilder);
|
||||
this._dataTable.removeComponents(formBuilder);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.onLoading();
|
||||
|
||||
try {
|
||||
if (this._tableSelectionComponent) {
|
||||
this._tableSelectionComponent.refresh();
|
||||
}
|
||||
await this._dataTable?.refresh();
|
||||
} catch (err) {
|
||||
this.showErrorMessage(constants.getErrorMessage(err));
|
||||
} finally {
|
||||
await this.onLoaded();
|
||||
}
|
||||
}
|
||||
|
||||
public get data(): ImportedModel[] | undefined {
|
||||
return this._dataTable?.data;
|
||||
}
|
||||
|
||||
private async onTableSelected(): Promise<void> {
|
||||
if (this._tableSelectionComponent?.data) {
|
||||
this.importTable = this._tableSelectionComponent?.data;
|
||||
await this.storeImportConfigTable();
|
||||
await this._dataTable?.refresh();
|
||||
}
|
||||
}
|
||||
|
||||
public get modelTable(): CurrentModelsTable | undefined {
|
||||
return this._dataTable;
|
||||
}
|
||||
|
||||
/**
|
||||
* disposes the view
|
||||
*/
|
||||
public async disposeComponent(): Promise<void> {
|
||||
if (this._dataTable) {
|
||||
await this._dataTable.disposeComponent();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* returns the title of the page
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.currentModelsTitle;
|
||||
}
|
||||
|
||||
private async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
}
|
||||
}
|
||||
|
||||
private async onLoaded(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: false });
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,314 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as vscode from 'vscode';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
import { IDataComponent, IComponentSettings } from '../../interfaces';
|
||||
import { ModelArtifact } from '../prediction/modelArtifact';
|
||||
import * as utils from '../../../common/utils';
|
||||
|
||||
/**
|
||||
* View to render registered models table
|
||||
*/
|
||||
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<ImportedModel[]> {
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
private _modelBuilder: azdata.ModelBuilder | undefined;
|
||||
private _selectedModel: ImportedModel[] = [];
|
||||
private _loader: azdata.LoadingComponent | undefined;
|
||||
private _downloadedFile: ModelArtifact | undefined;
|
||||
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
|
||||
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
|
||||
|
||||
/**
|
||||
* Creates new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._modelBuilder = modelBuilder;
|
||||
let columns = [
|
||||
{ // Name
|
||||
displayName: constants.modelName,
|
||||
ariaLabel: constants.modelName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Created
|
||||
displayName: constants.modelCreated,
|
||||
ariaLabel: constants.modelCreated,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
];
|
||||
if (this._settings.editable) {
|
||||
columns.push(
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
);
|
||||
}
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: columns,
|
||||
data: [],
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
this._loader = modelBuilder.loadingComponent()
|
||||
.withItem(this._table)
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.component) {
|
||||
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component });
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this.component) {
|
||||
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the data in the component
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
await this.onLoading();
|
||||
if (this._table) {
|
||||
let models: ImportedModel[] | undefined;
|
||||
|
||||
if (this.importTable) {
|
||||
models = await this.listModels(this.importTable);
|
||||
} else {
|
||||
this.showErrorMessage('No import table');
|
||||
}
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (models) {
|
||||
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
this.onModelSelected();
|
||||
await this.onLoaded();
|
||||
}
|
||||
|
||||
public async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
}
|
||||
}
|
||||
|
||||
public async onLoaded(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: false });
|
||||
}
|
||||
}
|
||||
|
||||
private createTableRow(model: ImportedModel): any[] {
|
||||
let row: any[] = [model.modelName, model.created];
|
||||
if (this._modelBuilder) {
|
||||
const selectButton = this.createSelectButton(model);
|
||||
if (selectButton) {
|
||||
row.push(selectButton);
|
||||
}
|
||||
const editButtons = this.createEditButtons(model);
|
||||
if (editButtons && editButtons.length > 0) {
|
||||
row = row.concat(editButtons);
|
||||
}
|
||||
}
|
||||
|
||||
return row;
|
||||
}
|
||||
|
||||
private createSelectButton(model: ImportedModel): azdata.Component | undefined {
|
||||
let selectModelButton: azdata.Component | undefined = undefined;
|
||||
if (this._modelBuilder && this._settings.selectable) {
|
||||
|
||||
let onSelectItem = (checked: boolean) => {
|
||||
if (!this._settings.multiSelect) {
|
||||
this._selectedModel = [];
|
||||
}
|
||||
const foundItem = this._selectedModel.find(x => x === model);
|
||||
if (checked && !foundItem) {
|
||||
this._selectedModel.push(model);
|
||||
} else if (foundItem) {
|
||||
this._selectedModel = this._selectedModel.filter(x => x !== model);
|
||||
}
|
||||
this.onModelSelected();
|
||||
};
|
||||
if (this._settings.multiSelect) {
|
||||
const checkbox = this._modelBuilder.checkBox().withProperties({
|
||||
name: 'amlModel',
|
||||
value: model.id,
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: false
|
||||
}).component();
|
||||
checkbox.onChanged(() => {
|
||||
onSelectItem(checkbox.checked || false);
|
||||
});
|
||||
selectModelButton = checkbox;
|
||||
} else {
|
||||
const radioButton = this._modelBuilder.radioButton().withProperties({
|
||||
name: 'amlModel',
|
||||
value: model.id,
|
||||
width: 15,
|
||||
height: 15,
|
||||
checked: false
|
||||
}).component();
|
||||
radioButton.onDidClick(() => {
|
||||
onSelectItem(radioButton.checked || false);
|
||||
});
|
||||
selectModelButton = radioButton;
|
||||
}
|
||||
}
|
||||
return selectModelButton;
|
||||
}
|
||||
|
||||
private createEditButtons(model: ImportedModel): azdata.Component[] | undefined {
|
||||
let dropButton: azdata.ButtonComponent | undefined = undefined;
|
||||
let editButton: azdata.ButtonComponent | undefined = undefined;
|
||||
if (this._modelBuilder && this._settings.editable) {
|
||||
dropButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
title: constants.deleteTitle,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/delete.svg')
|
||||
},
|
||||
width: 15,
|
||||
height: 15
|
||||
}).component();
|
||||
dropButton.onDidClick(async () => {
|
||||
try {
|
||||
const confirm = await utils.promptConfirm(constants.confirmDeleteModel(model.modelName), this._apiWrapper);
|
||||
if (confirm) {
|
||||
await this.sendDataRequest(DeleteModelEventName, model);
|
||||
if (this.parent) {
|
||||
await this.parent?.refresh();
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`);
|
||||
}
|
||||
});
|
||||
|
||||
editButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
title: constants.deleteTitle,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/edit.svg')
|
||||
},
|
||||
width: 15,
|
||||
height: 15
|
||||
}).component();
|
||||
editButton.onDidClick(async () => {
|
||||
await this.sendDataRequest(EditModelEventName, model);
|
||||
});
|
||||
}
|
||||
return editButton && dropButton ? [editButton, dropButton] : undefined;
|
||||
}
|
||||
|
||||
private async onModelSelected(): Promise<void> {
|
||||
this._onModelSelectionChanged.fire();
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
this._downloadedFile = undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ImportedModel[] | undefined {
|
||||
return this._selectedModel;
|
||||
}
|
||||
|
||||
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
|
||||
if (!this._downloadedFile && this.data && this.data.length > 0) {
|
||||
this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data[0]));
|
||||
}
|
||||
return this._downloadedFile;
|
||||
}
|
||||
|
||||
/**
|
||||
* disposes the view
|
||||
*/
|
||||
public async disposeComponent(): Promise<void> {
|
||||
if (this._downloadedFile) {
|
||||
await this._downloadedFile.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { ModelViewBase, UpdateModelEventName } from '../modelViewBase';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { DialogView } from '../../dialogView';
|
||||
import { ModelDetailsEditPage } from './modelDetailsEditPage';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* Dialog to render registered model views
|
||||
*/
|
||||
export class EditModelDialog extends ModelViewBase {
|
||||
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
root: string,
|
||||
private _parentView: ModelViewBase | undefined,
|
||||
private _model: ImportedModel) {
|
||||
super(apiWrapper, root);
|
||||
this.dialogView = new DialogView(this._apiWrapper);
|
||||
}
|
||||
public dialogView: DialogView;
|
||||
public editModelPage: ModelDetailsEditPage | undefined;
|
||||
|
||||
/**
|
||||
* Opens a dialog to edit models.
|
||||
*/
|
||||
public open(): void {
|
||||
|
||||
this.editModelPage = new ModelDetailsEditPage(this._apiWrapper, this, this._model);
|
||||
|
||||
let registerModelButton = this._apiWrapper.createButton(constants.extLangSaveButtonText);
|
||||
registerModelButton.onClick(async () => {
|
||||
if (this.editModelPage) {
|
||||
const valid = await this.editModelPage.validate();
|
||||
if (valid) {
|
||||
try {
|
||||
await this.sendDataRequest(UpdateModelEventName, this.editModelPage?.data);
|
||||
this.showInfoMessage(constants.modelUpdatedSuccessfully);
|
||||
if (this._parentView) {
|
||||
await this._parentView.refresh();
|
||||
}
|
||||
} catch (error) {
|
||||
this.showInfoMessage(`${constants.modelUpdateFailedError} ${constants.getErrorMessage(error)}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let dialog = this.dialogView.createDialog(constants.editModelTitle, [this.editModelPage]);
|
||||
dialog.customButtons = [registerModelButton];
|
||||
this.mainViewPanel = dialog;
|
||||
dialog.okButton.hidden = true;
|
||||
dialog.cancelButton.label = constants.extLangDoneButtonText;
|
||||
|
||||
dialog.registerCloseValidator(() => {
|
||||
return false; // Blocks Enter key from closing dialog.
|
||||
});
|
||||
|
||||
this._apiWrapper.openDialog(dialog);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the tabs for given provider Id
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this.dialogView) {
|
||||
this.dialogView.refresh();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { ModelSourcesComponent } from '../modelSourcesComponent';
|
||||
import { LocalModelsComponent } from '../localModelsComponent';
|
||||
import { AzureModelsComponent } from '../azureModelsComponent';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { WizardView } from '../../wizardView';
|
||||
import { ModelSourcePage } from '../modelSourcePage';
|
||||
import { ModelDetailsPage } from '../modelDetailsPage';
|
||||
import { ModelBrowsePage } from '../modelBrowsePage';
|
||||
import { ModelImportLocationPage } from './modelImportLocationPage';
|
||||
|
||||
/**
|
||||
* Wizard to register a model
|
||||
*/
|
||||
export class ImportModelWizard extends ModelViewBase {
|
||||
|
||||
public modelSourcePage: ModelSourcePage | undefined;
|
||||
public modelBrowsePage: ModelBrowsePage | undefined;
|
||||
public modelDetailsPage: ModelDetailsPage | undefined;
|
||||
public modelImportTargetPage: ModelImportLocationPage | undefined;
|
||||
public wizardView: WizardView | undefined;
|
||||
private _parentView: ModelViewBase | undefined;
|
||||
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
root: string,
|
||||
parent?: ModelViewBase) {
|
||||
super(apiWrapper, root);
|
||||
this._parentView = parent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens a dialog to manage packages used by notebooks.
|
||||
*/
|
||||
public async open(): Promise<void> {
|
||||
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
|
||||
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
|
||||
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this);
|
||||
this.modelImportTargetPage = new ModelImportLocationPage(this._apiWrapper, this);
|
||||
this.wizardView = new WizardView(this._apiWrapper);
|
||||
|
||||
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelImportTargetPage, this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]);
|
||||
|
||||
this.mainViewPanel = wizard;
|
||||
wizard.doneButton.label = constants.azureRegisterModel;
|
||||
wizard.generateScriptButton.hidden = true;
|
||||
wizard.displayPageTitles = true;
|
||||
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
|
||||
let validated: boolean = true;
|
||||
if (pageInfo.newPage > pageInfo.lastPage) {
|
||||
validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
|
||||
}
|
||||
if (validated && pageInfo.newPage === undefined) {
|
||||
wizard.cancelButton.enabled = false;
|
||||
wizard.backButton.enabled = false;
|
||||
let result = await this.registerModel();
|
||||
wizard.cancelButton.enabled = true;
|
||||
wizard.backButton.enabled = true;
|
||||
if (this._parentView) {
|
||||
this._parentView.importTable = this.importTable;
|
||||
await this._parentView.refresh();
|
||||
}
|
||||
return result;
|
||||
|
||||
}
|
||||
return validated;
|
||||
});
|
||||
|
||||
await wizard.open();
|
||||
}
|
||||
|
||||
public get modelResources(): ModelSourcesComponent | undefined {
|
||||
return this.modelSourcePage?.modelResources;
|
||||
}
|
||||
|
||||
public get localModelsComponent(): LocalModelsComponent | undefined {
|
||||
return this.modelBrowsePage?.localModelsComponent;
|
||||
}
|
||||
|
||||
public get azureModelsComponent(): AzureModelsComponent | undefined {
|
||||
return this.modelBrowsePage?.azureModelsComponent;
|
||||
}
|
||||
|
||||
private async registerModel(): Promise<boolean> {
|
||||
try {
|
||||
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
|
||||
await this.importLocalModel(this.modelsViewData);
|
||||
} else {
|
||||
await this.importAzureModel(this.modelsViewData);
|
||||
}
|
||||
await this.storeImportConfigTable();
|
||||
this.showInfoMessage(constants.modelRegisteredSuccessfully);
|
||||
return true;
|
||||
} catch (error) {
|
||||
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh the pages
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.wizardView?.refresh();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { CurrentModelsComponent } from './currentModelsComponent';
|
||||
|
||||
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { DialogView } from '../../dialogView';
|
||||
|
||||
/**
|
||||
* Dialog to render registered model views
|
||||
*/
|
||||
export class ManageModelsDialog extends ModelViewBase {
|
||||
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
root: string) {
|
||||
super(apiWrapper, root);
|
||||
this.dialogView = new DialogView(this._apiWrapper);
|
||||
}
|
||||
public dialogView: DialogView;
|
||||
public currentLanguagesTab: CurrentModelsComponent | undefined;
|
||||
|
||||
/**
|
||||
* Opens a dialog to manage packages used by notebooks.
|
||||
*/
|
||||
public open(): void {
|
||||
|
||||
this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this, {
|
||||
editable: true,
|
||||
selectable: false
|
||||
});
|
||||
|
||||
let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle);
|
||||
registerModelButton.onClick(async () => {
|
||||
await this.sendDataRequest(RegisterModelEventName, this.currentLanguagesTab?.modelTable?.importTable);
|
||||
});
|
||||
|
||||
let dialog = this.dialogView.createDialog(constants.registerModelTitle, [this.currentLanguagesTab]);
|
||||
dialog.customButtons = [registerModelButton];
|
||||
this.mainViewPanel = dialog;
|
||||
dialog.okButton.hidden = true;
|
||||
dialog.cancelButton.label = constants.extLangDoneButtonText;
|
||||
|
||||
dialog.registerCloseValidator(() => {
|
||||
return false; // Blocks Enter key from closing dialog.
|
||||
});
|
||||
|
||||
this._apiWrapper.openDialog(dialog);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the tabs for given provider Id
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this.dialogView) {
|
||||
this.dialogView.refresh();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<ImportedModel> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _nameComponent: azdata.InputBoxComponent | undefined;
|
||||
private _descriptionComponent: azdata.InputBoxComponent | undefined;
|
||||
private _createdComponent: azdata.Component | undefined;
|
||||
private _deployedComponent: azdata.Component | undefined;
|
||||
private _frameworkComponent: azdata.Component | undefined;
|
||||
private _frameworkVersionComponent: azdata.Component | undefined;
|
||||
/**
|
||||
* Creates a new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._createdComponent = modelBuilder.text().withProperties({
|
||||
value: this._model.created
|
||||
}).component();
|
||||
this._deployedComponent = modelBuilder.text().withProperties({
|
||||
value: this._model.deploymentTime
|
||||
}).component();
|
||||
this._frameworkComponent = modelBuilder.text().withProperties({
|
||||
value: this._model.framework
|
||||
}).component();
|
||||
this._frameworkVersionComponent = modelBuilder.text().withProperties({
|
||||
value: this._model.frameworkVersion
|
||||
}).component();
|
||||
this._nameComponent = modelBuilder.inputBox().withProperties({
|
||||
width: this.componentMaxLength,
|
||||
value: this._model.modelName
|
||||
}).component();
|
||||
this._descriptionComponent = modelBuilder.inputBox().withProperties({
|
||||
width: this.componentMaxLength,
|
||||
value: this._model.description,
|
||||
multiline: true,
|
||||
height: 50
|
||||
}).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
component: this._nameComponent
|
||||
},
|
||||
{
|
||||
title: '',
|
||||
component: this._descriptionComponent
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
|
||||
formBuilder.addFormItems([{
|
||||
title: constants.modelName,
|
||||
component: this._nameComponent
|
||||
}, {
|
||||
title: constants.modelCreated,
|
||||
component: this._createdComponent
|
||||
},
|
||||
{
|
||||
title: constants.modelDeployed,
|
||||
component: this._deployedComponent
|
||||
}, {
|
||||
title: constants.modelFramework,
|
||||
component: this._frameworkComponent
|
||||
}, {
|
||||
title: constants.modelFrameworkVersion,
|
||||
component: this._frameworkVersionComponent
|
||||
}, {
|
||||
title: constants.modelDescription,
|
||||
component: this._descriptionComponent
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelCreated,
|
||||
component: this._createdComponent
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelCreated,
|
||||
component: this._frameworkComponent
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelCreated,
|
||||
component: this._frameworkVersionComponent
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelCreated,
|
||||
component: this._deployedComponent
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelName,
|
||||
component: this._nameComponent
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.modelDescription,
|
||||
component: this._descriptionComponent
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the created component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ImportedModel | undefined {
|
||||
let model = Object.assign({}, this._model);
|
||||
model.modelName = this._nameComponent?.value || '';
|
||||
model.description = this._descriptionComponent?.value || '';
|
||||
return model;
|
||||
}
|
||||
|
||||
/**
|
||||
* loads data in the components
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
}
|
||||
|
||||
/**
|
||||
* refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../../interfaces';
|
||||
import { ImportedModel } from '../../../modelManagement/interfaces';
|
||||
import { ModelDetailsComponent } from './modelDetailsComponent';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ModelDetailsEditPage extends ModelViewBase implements IPageView, IDataComponent<ImportedModel> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public modelDetailsComponent: ModelDetailsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.modelDetailsComponent = new ModelDetailsComponent(this._apiWrapper, this, this._model);
|
||||
|
||||
this.modelDetailsComponent.registerComponent(modelBuilder);
|
||||
this.modelDetailsComponent.addComponents(this._formBuilder);
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ImportedModel | undefined {
|
||||
return this.modelDetailsComponent?.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this.modelDetailsComponent) {
|
||||
await this.modelDetailsComponent.refresh();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.modelImportTargetPageTitle;
|
||||
}
|
||||
|
||||
public async disposePage(): Promise<void> {
|
||||
}
|
||||
|
||||
public async validate(): Promise<boolean> {
|
||||
let validated = false;
|
||||
|
||||
if (this.data?.modelName) {
|
||||
validated = true;
|
||||
} else {
|
||||
this.showErrorMessage(constants.modelNameRequiredError);
|
||||
}
|
||||
return validated;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../../interfaces';
|
||||
import { TableSelectionComponent } from '../tableSelectionComponent';
|
||||
import { DatabaseTable } from '../../../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ModelImportLocationPage extends ModelViewBase implements IPageView, IDataComponent<DatabaseTable> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public tableSelectionComponent: TableSelectionComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, true);
|
||||
this.tableSelectionComponent.onSelectedChanged(async () => {
|
||||
await this.onTableSelected();
|
||||
});
|
||||
this.tableSelectionComponent.registerComponent(modelBuilder);
|
||||
this.tableSelectionComponent.addComponents(this._formBuilder);
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
private async onTableSelected(): Promise<void> {
|
||||
if (this.tableSelectionComponent?.data) {
|
||||
this.importTable = this.tableSelectionComponent?.data;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): DatabaseTable | undefined {
|
||||
return this.tableSelectionComponent?.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this.tableSelectionComponent) {
|
||||
await this.tableSelectionComponent.refresh();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.modelImportTargetPageTitle;
|
||||
}
|
||||
|
||||
public async disposePage(): Promise<void> {
|
||||
}
|
||||
|
||||
public async validate(): Promise<boolean> {
|
||||
let validated = false;
|
||||
|
||||
if (this.data?.databaseName && this.data?.tableName) {
|
||||
validated = true;
|
||||
validated = await this.verifyImportConfigTable(this.data);
|
||||
if (!validated) {
|
||||
this.showErrorMessage(constants.invalidImportTableSchemaError(this.data?.databaseName, this.data?.tableName));
|
||||
}
|
||||
} else {
|
||||
this.showErrorMessage(constants.invalidImportTableError(this.data?.databaseName, this.data?.tableName));
|
||||
}
|
||||
return validated;
|
||||
}
|
||||
}
|
||||
207
extensions/machine-learning/src/views/models/modelBrowsePage.ts
Normal file
207
extensions/machine-learning/src/views/models/modelBrowsePage.ts
Normal file
@@ -0,0 +1,207 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, ModelSourceType, ModelViewData } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { LocalModelsComponent } from './localModelsComponent';
|
||||
import { AzureModelsComponent } from './azureModelsComponent';
|
||||
import * as utils from '../../common/utils';
|
||||
import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _title: string = constants.localModelPageTitle;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public localModelsComponent: LocalModelsComponent | undefined;
|
||||
public azureModelsComponent: AzureModelsComponent | undefined;
|
||||
public registeredModelsComponent: CurrentModelsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this, this._multiSelect);
|
||||
this.localModelsComponent.registerComponent(modelBuilder);
|
||||
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect);
|
||||
this.azureModelsComponent.registerComponent(modelBuilder);
|
||||
this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, {
|
||||
selectable: true,
|
||||
multiSelect: this._multiSelect,
|
||||
editable: false
|
||||
});
|
||||
this.registeredModelsComponent.registerComponent(modelBuilder);
|
||||
this.refresh();
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ModelViewData[] {
|
||||
return this.modelsViewData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._formBuilder) {
|
||||
if (this.modelSourceType === ModelSourceType.Local) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||
this.localModelsComponent.addComponents(this._formBuilder);
|
||||
await this.localModelsComponent.refresh();
|
||||
}
|
||||
|
||||
} else if (this.modelSourceType === ModelSourceType.Azure) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||
this.azureModelsComponent.addComponents(this._formBuilder);
|
||||
this.registeredModelsComponent.removeComponents(this._formBuilder);
|
||||
await this.azureModelsComponent.refresh();
|
||||
}
|
||||
|
||||
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
|
||||
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
|
||||
this.localModelsComponent.removeComponents(this._formBuilder);
|
||||
this.azureModelsComponent.removeComponents(this._formBuilder);
|
||||
this.registeredModelsComponent.addComponents(this._formBuilder);
|
||||
await this.registeredModelsComponent.refresh();
|
||||
}
|
||||
}
|
||||
}
|
||||
this.loadTitle();
|
||||
}
|
||||
|
||||
private loadTitle(): void {
|
||||
if (this.modelSourceType === ModelSourceType.Local) {
|
||||
this._title = constants.localModelPageTitle;
|
||||
} else if (this.modelSourceType === ModelSourceType.Azure) {
|
||||
this._title = constants.azureModelPageTitle;
|
||||
|
||||
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
|
||||
this._title = constants.importedModelsPageTitle;
|
||||
} else {
|
||||
this._title = constants.modelSourcePageTitle;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
this.loadTitle();
|
||||
return this._title;
|
||||
}
|
||||
|
||||
public validate(): Promise<boolean> {
|
||||
let validated = false;
|
||||
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
|
||||
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
|
||||
|
||||
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
|
||||
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0;
|
||||
|
||||
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
|
||||
validated = this.registeredModelsComponent.data !== undefined && this.registeredModelsComponent.data.length > 0;
|
||||
}
|
||||
if (!validated) {
|
||||
this.showErrorMessage(constants.invalidModelToSelectError);
|
||||
}
|
||||
return Promise.resolve(validated);
|
||||
}
|
||||
|
||||
public onEnter(): Promise<void> {
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
public async onLeave(): Promise<void> {
|
||||
this.modelsViewData = [];
|
||||
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
|
||||
if (this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0) {
|
||||
this.modelsViewData = this.localModelsComponent.data.map(x => {
|
||||
const fileName = utils.getFileName(x);
|
||||
return {
|
||||
modelData: x,
|
||||
modelDetails: {
|
||||
modelName: fileName,
|
||||
fileName: fileName
|
||||
},
|
||||
targetImportTable: this.importTable
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
|
||||
if (this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0) {
|
||||
this.modelsViewData = this.azureModelsComponent.data.map(x => {
|
||||
return {
|
||||
modelData: {
|
||||
account: x.account,
|
||||
subscription: x.subscription,
|
||||
group: x.group,
|
||||
workspace: x.workspace,
|
||||
model: x.model
|
||||
},
|
||||
modelDetails: {
|
||||
modelName: x.model?.name || '',
|
||||
fileName: x.model?.name,
|
||||
framework: x.model?.framework,
|
||||
frameworkVersion: x.model?.frameworkVersion,
|
||||
created: x.model?.createdTime
|
||||
},
|
||||
targetImportTable: this.importTable
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
|
||||
if (this.registeredModelsComponent.data !== undefined) {
|
||||
this.modelsViewData = this.registeredModelsComponent.data.map(x => {
|
||||
return {
|
||||
modelData: x,
|
||||
modelDetails: {
|
||||
modelName: ''
|
||||
},
|
||||
targetImportTable: this.importTable
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async disposePage(): Promise<void> {
|
||||
if (this.azureModelsComponent) {
|
||||
await this.azureModelsComponent.disposeComponent();
|
||||
|
||||
}
|
||||
if (this.registeredModelsComponent) {
|
||||
await this.registeredModelsComponent.disposeComponent();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, ModelViewData } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { ModelsDetailsTableComponent } from './modelsDetailsTableComponent';
|
||||
|
||||
/**
|
||||
* View to pick model details
|
||||
*/
|
||||
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public modelDetails: ModelsDetailsTableComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.modelDetails = new ModelsDetailsTableComponent(this._apiWrapper, modelBuilder, this);
|
||||
this.modelDetails.registerComponent(modelBuilder);
|
||||
this.modelDetails.addComponents(this._formBuilder);
|
||||
this.refresh();
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ModelViewData[] | undefined {
|
||||
return this.modelDetails?.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this.modelDetails) {
|
||||
await this.modelDetails.refresh();
|
||||
}
|
||||
}
|
||||
|
||||
public async onEnter(): Promise<void> {
|
||||
await this.refresh();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.modelDetailsPageTitle;
|
||||
}
|
||||
|
||||
public validate(): Promise<boolean> {
|
||||
if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.modelName)) {
|
||||
return Promise.resolve(true);
|
||||
} else {
|
||||
this.showErrorMessage(constants.modelNameRequiredError);
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,425 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { ImportedModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import { DeployedModelService } from '../../modelManagement/deployedModelService';
|
||||
import { ManageModelsDialog } from './manageModels/manageModelsDialog';
|
||||
import {
|
||||
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
|
||||
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
|
||||
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
|
||||
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName
|
||||
} from './modelViewBase';
|
||||
import { ControllerBase } from '../controllerBase';
|
||||
import { ImportModelWizard } from './manageModels/importModelWizard';
|
||||
import * as fs from 'fs';
|
||||
import * as constants from '../../common/constants';
|
||||
import { PredictWizard } from './prediction/predictWizard';
|
||||
import { AzureModelResource } from '../interfaces';
|
||||
import { PredictService } from '../../prediction/predictService';
|
||||
import { EditModelDialog } from './manageModels/editModelDialog';
|
||||
|
||||
/**
|
||||
* Model management UI controller
|
||||
*/
|
||||
export class ModelManagementController extends ControllerBase {
|
||||
|
||||
/**
|
||||
* Creates new instance
|
||||
*/
|
||||
constructor(
|
||||
apiWrapper: ApiWrapper,
|
||||
private _root: string,
|
||||
private _amlService: AzureModelRegistryService,
|
||||
private _registeredModelService: DeployedModelService,
|
||||
private _predictService: PredictService) {
|
||||
super(apiWrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the dialog for model registration
|
||||
* @param parent parent if the view is opened from another view
|
||||
* @param controller controller
|
||||
* @param apiWrapper apiWrapper
|
||||
* @param root root folder path
|
||||
*/
|
||||
public async registerModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
|
||||
controller = controller || this;
|
||||
apiWrapper = apiWrapper || this._apiWrapper;
|
||||
root = root || this._root;
|
||||
let view = new ImportModelWizard(apiWrapper, root, parent);
|
||||
if (importTable) {
|
||||
view.importTable = importTable;
|
||||
} else {
|
||||
view.importTable = await controller._registeredModelService.getRecentImportTable();
|
||||
}
|
||||
|
||||
controller.registerEvents(view);
|
||||
|
||||
// Open view
|
||||
//
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the dialog to edit model
|
||||
*/
|
||||
public async editModel(model: ImportedModel, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
|
||||
controller = controller || this;
|
||||
apiWrapper = apiWrapper || this._apiWrapper;
|
||||
root = root || this._root;
|
||||
let view = new EditModelDialog(apiWrapper, root, parent, model);
|
||||
|
||||
controller.registerEvents(view);
|
||||
|
||||
// Open view
|
||||
//
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the wizard for prediction
|
||||
*/
|
||||
public async predictModel(): Promise<ModelViewBase> {
|
||||
|
||||
let view = new PredictWizard(this._apiWrapper, this._root);
|
||||
view.importTable = await this._registeredModelService.getRecentImportTable();
|
||||
|
||||
this.registerEvents(view);
|
||||
view.on(LoadModelParametersEventName, async () => {
|
||||
const modelArtifact = await view.getModelFileName();
|
||||
await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
|
||||
modelArtifact?.filePath);
|
||||
});
|
||||
|
||||
// Open view
|
||||
//
|
||||
await view.open();
|
||||
await view.refresh();
|
||||
return view;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Register events in the main view
|
||||
* @param view main view
|
||||
*/
|
||||
public registerEvents(view: ModelViewBase): void {
|
||||
|
||||
// Register events
|
||||
//
|
||||
super.registerEvents(view);
|
||||
view.on(ListAccountsEventName, async () => {
|
||||
await this.executeAction(view, ListAccountsEventName, this.getAzureAccounts, this._amlService);
|
||||
});
|
||||
view.on(ListSubscriptionsEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListSubscriptionsEventName, this.getAzureSubscriptions, this._amlService, azureArgs.account);
|
||||
});
|
||||
view.on(ListWorkspacesEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListWorkspacesEventName, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group);
|
||||
});
|
||||
view.on(ListGroupsEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListGroupsEventName, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription);
|
||||
});
|
||||
view.on(ListAzureModelsEventName, async (arg) => {
|
||||
let azureArgs = <AzureResourceEventArgs>arg;
|
||||
await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService
|
||||
, azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace);
|
||||
});
|
||||
view.on(ListModelsEventName, async (args) => {
|
||||
const table = <DatabaseTable>args;
|
||||
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService, table);
|
||||
});
|
||||
view.on(RegisterLocalModelEventName, async (arg) => {
|
||||
let models = <ModelViewData[]>arg;
|
||||
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models);
|
||||
view.refresh();
|
||||
});
|
||||
view.on(RegisterModelEventName, async (args) => {
|
||||
const importTable = <DatabaseTable>args;
|
||||
await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root);
|
||||
});
|
||||
view.on(EditModelEventName, async (args) => {
|
||||
const model = <ImportedModel>args;
|
||||
await this.executeAction(view, EditModelEventName, this.editModel, model, view, this, this._apiWrapper, this._root);
|
||||
});
|
||||
view.on(UpdateModelEventName, async (args) => {
|
||||
const model = <ImportedModel>args;
|
||||
await this.executeAction(view, UpdateModelEventName, this.updateModel, this._registeredModelService, model);
|
||||
});
|
||||
view.on(DeleteModelEventName, async (args) => {
|
||||
const model = <ImportedModel>args;
|
||||
await this.executeAction(view, DeleteModelEventName, this.deleteModel, this._registeredModelService, model);
|
||||
});
|
||||
view.on(RegisterAzureModelEventName, async (arg) => {
|
||||
let models = <ModelViewData[]>arg;
|
||||
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
|
||||
models);
|
||||
});
|
||||
view.on(DownloadAzureModelEventName, async (arg) => {
|
||||
let registerArgs = <AzureModelResource>arg;
|
||||
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
|
||||
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
|
||||
});
|
||||
view.on(ListDatabaseNamesEventName, async () => {
|
||||
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
|
||||
});
|
||||
view.on(ListTableNamesEventName, async (arg) => {
|
||||
let dbName = <string>arg;
|
||||
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
|
||||
});
|
||||
view.on(ListColumnNamesEventName, async (arg) => {
|
||||
let tableColumnsArgs = <DatabaseTable>arg;
|
||||
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
|
||||
tableColumnsArgs);
|
||||
});
|
||||
view.on(PredictModelEventName, async (arg) => {
|
||||
let predictArgs = <PredictModelEventArgs>arg;
|
||||
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
|
||||
predictArgs, predictArgs.model, predictArgs.filePath);
|
||||
});
|
||||
view.on(DownloadRegisteredModelEventName, async (arg) => {
|
||||
let model = <ImportedModel>arg;
|
||||
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
|
||||
model);
|
||||
});
|
||||
view.on(StoreImportTableEventName, async (arg) => {
|
||||
let importTable = <DatabaseTable>arg;
|
||||
await this.executeAction(view, StoreImportTableEventName, this.storeImportTable, this._registeredModelService,
|
||||
importTable);
|
||||
});
|
||||
view.on(VerifyImportTableEventName, async (arg) => {
|
||||
let importTable = <DatabaseTable>arg;
|
||||
await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService,
|
||||
importTable);
|
||||
});
|
||||
view.on(SourceModelSelectedEventName, async (arg) => {
|
||||
view.modelSourceType = <ModelSourceType>arg;
|
||||
await view.refresh();
|
||||
});
|
||||
view.on(SignInToAzureEventName, async () => {
|
||||
await this.executeAction(view, SignInToAzureEventName, this.signInToAzure, this._amlService);
|
||||
await view.refresh();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens the dialog for model management
|
||||
*/
|
||||
public async manageRegisteredModels(importTable?: DatabaseTable): Promise<ModelViewBase> {
|
||||
let view = new ManageModelsDialog(this._apiWrapper, this._root);
|
||||
|
||||
if (importTable) {
|
||||
view.importTable = importTable;
|
||||
} else {
|
||||
view.importTable = await this._registeredModelService.getRecentImportTable();
|
||||
}
|
||||
|
||||
// Register events
|
||||
//
|
||||
this.registerEvents(view);
|
||||
|
||||
// Open view
|
||||
//
|
||||
view.open();
|
||||
return view;
|
||||
}
|
||||
|
||||
private async signInToAzure(service: AzureModelRegistryService): Promise<void> {
|
||||
return await service.signInToAzure();
|
||||
}
|
||||
|
||||
private async getAzureAccounts(service: AzureModelRegistryService): Promise<azdata.Account[]> {
|
||||
return await service.getAccounts();
|
||||
}
|
||||
|
||||
private async getAzureSubscriptions(service: AzureModelRegistryService, account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[] | undefined> {
|
||||
return await service.getSubscriptions(account);
|
||||
}
|
||||
|
||||
private async getAzureGroups(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[] | undefined> {
|
||||
return await service.getGroups(account, subscription);
|
||||
}
|
||||
|
||||
private async getWorkspaces(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[] | undefined> {
|
||||
if (!account || !subscription) {
|
||||
return [];
|
||||
}
|
||||
return await service.getWorkspaces(account, subscription, group);
|
||||
}
|
||||
|
||||
private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise<ImportedModel[]> {
|
||||
return registeredModelService.getDeployedModels(table);
|
||||
}
|
||||
|
||||
private async getAzureModels(
|
||||
service: AzureModelRegistryService,
|
||||
account: azdata.Account | undefined,
|
||||
subscription: azureResource.AzureResourceSubscription | undefined,
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
workspace: Workspace | undefined): Promise<WorkspaceModel[]> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace) {
|
||||
return [];
|
||||
}
|
||||
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
|
||||
}
|
||||
|
||||
private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise<void> {
|
||||
if (models) {
|
||||
await Promise.all(models.map(async (model) => {
|
||||
if (model && model.targetImportTable) {
|
||||
const localModel = <string>model.modelData;
|
||||
if (localModel) {
|
||||
await service.deployLocalModel(localModel, model.modelDetails, model.targetImportTable);
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}));
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
|
||||
private async updateModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
|
||||
if (model) {
|
||||
await service.updateModel(model);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
|
||||
private async deleteModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
|
||||
if (model) {
|
||||
await service.deleteModel(model);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
|
||||
private async registerAzureModel(
|
||||
azureService: AzureModelRegistryService,
|
||||
service: DeployedModelService,
|
||||
models: ModelViewData[] | undefined): Promise<void> {
|
||||
if (!models) {
|
||||
throw Error(constants.invalidAzureResourceError);
|
||||
}
|
||||
|
||||
await Promise.all(models.map(async (model) => {
|
||||
if (model && model.targetImportTable) {
|
||||
const azureModel = <AzureModelResource>model.modelData;
|
||||
if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) {
|
||||
let filePath: string | undefined;
|
||||
try {
|
||||
const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group,
|
||||
azureModel.workspace, azureModel.model);
|
||||
if (filePath) {
|
||||
await service.deployLocalModel(filePath, model.modelDetails, model.targetImportTable);
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
} finally {
|
||||
if (filePath) {
|
||||
await fs.promises.unlink(filePath);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
private async getDatabaseList(predictService: PredictService): Promise<string[]> {
|
||||
return await predictService.getDatabaseList();
|
||||
}
|
||||
|
||||
private async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
|
||||
return await predictService.getTableList(databaseName);
|
||||
}
|
||||
|
||||
private async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
|
||||
return await predictService.getTableColumnsList(databaseTable);
|
||||
}
|
||||
|
||||
private async generatePredictScript(
|
||||
predictService: PredictService,
|
||||
predictParams: PredictParameters,
|
||||
registeredModel: ImportedModel | undefined,
|
||||
filePath: string | undefined
|
||||
): Promise<string> {
|
||||
if (!predictParams) {
|
||||
throw Error(constants.invalidModelToPredictError);
|
||||
}
|
||||
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
|
||||
return result;
|
||||
}
|
||||
|
||||
private async storeImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<void> {
|
||||
if (table) {
|
||||
await registeredModelService.storeRecentImportTable(table);
|
||||
} else {
|
||||
throw Error(constants.invalidImportTableError(undefined, undefined));
|
||||
}
|
||||
}
|
||||
|
||||
private async verifyImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<boolean> {
|
||||
if (table) {
|
||||
return await registeredModelService.verifyConfigTable(table);
|
||||
} else {
|
||||
throw Error(constants.invalidImportTableError(undefined, undefined));
|
||||
}
|
||||
}
|
||||
|
||||
private async downloadRegisteredModel(
|
||||
registeredModelService: DeployedModelService,
|
||||
model: ImportedModel | undefined): Promise<string> {
|
||||
if (!model) {
|
||||
throw Error(constants.invalidModelToPredictError);
|
||||
}
|
||||
return await registeredModelService.downloadModel(model);
|
||||
}
|
||||
|
||||
private async loadModelParameters(
|
||||
registeredModelService: DeployedModelService,
|
||||
model: string | undefined): Promise<ModelParameters | undefined> {
|
||||
if (!model) {
|
||||
return undefined;
|
||||
}
|
||||
return await registeredModelService.loadModelParameters(model);
|
||||
}
|
||||
|
||||
private async downloadAzureModel(
|
||||
azureService: AzureModelRegistryService,
|
||||
account: azdata.Account | undefined,
|
||||
subscription: azureResource.AzureResourceSubscription | undefined,
|
||||
resourceGroup: azureResource.AzureResource | undefined,
|
||||
workspace: Workspace | undefined,
|
||||
model: WorkspaceModel | undefined): Promise<string> {
|
||||
if (!account || !subscription || !resourceGroup || !workspace || !model) {
|
||||
throw Error(constants.invalidAzureResourceError);
|
||||
}
|
||||
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
|
||||
if (filePath) {
|
||||
return filePath;
|
||||
} else {
|
||||
throw Error(constants.invalidModelToRegisterError);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, ModelSourceType } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../interfaces';
|
||||
import { ModelSourcesComponent } from './modelSourcesComponent';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ModelSourcePage extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public modelResources: ModelSourcesComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
|
||||
this.modelResources.registerComponent(modelBuilder);
|
||||
this.modelResources.addComponents(this._formBuilder);
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ModelSourceType {
|
||||
return this.modelResources?.data || ModelSourceType.Local;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.modelSourcePageTitle;
|
||||
}
|
||||
|
||||
public async disposePage(): Promise<void> {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ModelSourcesComponent extends ModelViewBase implements IDataComponent<ModelSourceType> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _flexContainer: azdata.FlexContainer | undefined;
|
||||
private _amlModel: azdata.CardComponent | undefined;
|
||||
private _localModel: azdata.CardComponent | undefined;
|
||||
private _registeredModels: azdata.CardComponent | undefined;
|
||||
private _sourceType: ModelSourceType = ModelSourceType.Local;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
|
||||
this._localModel = modelBuilder.card()
|
||||
.withProperties({
|
||||
value: 'local',
|
||||
name: 'modelLocation',
|
||||
label: constants.localModelSource,
|
||||
selected: this._options[0] === ModelSourceType.Local,
|
||||
cardType: azdata.CardType.VerticalButton,
|
||||
width: 50
|
||||
}).component();
|
||||
this._amlModel = modelBuilder.card()
|
||||
.withProperties({
|
||||
value: 'aml',
|
||||
name: 'modelLocation',
|
||||
label: constants.azureModelSource,
|
||||
selected: this._options[0] === ModelSourceType.Azure,
|
||||
cardType: azdata.CardType.VerticalButton,
|
||||
width: 50
|
||||
}).component();
|
||||
|
||||
this._registeredModels = modelBuilder.card()
|
||||
.withProperties({
|
||||
value: 'registered',
|
||||
name: 'modelLocation',
|
||||
label: constants.registeredModelsSource,
|
||||
selected: this._options[0] === ModelSourceType.RegisteredModels,
|
||||
cardType: azdata.CardType.VerticalButton,
|
||||
width: 50
|
||||
}).component();
|
||||
|
||||
this._localModel.onCardSelectedChanged(() => {
|
||||
this._sourceType = ModelSourceType.Local;
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
if (this._amlModel && this._registeredModels) {
|
||||
this._amlModel.selected = false;
|
||||
this._registeredModels.selected = false;
|
||||
}
|
||||
});
|
||||
this._amlModel.onCardSelectedChanged(() => {
|
||||
this._sourceType = ModelSourceType.Azure;
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
if (this._localModel && this._registeredModels) {
|
||||
this._localModel.selected = false;
|
||||
this._registeredModels.selected = false;
|
||||
}
|
||||
});
|
||||
this._registeredModels.onCardSelectedChanged(() => {
|
||||
this._sourceType = ModelSourceType.RegisteredModels;
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
if (this._localModel && this._amlModel) {
|
||||
this._localModel.selected = false;
|
||||
this._amlModel.selected = false;
|
||||
}
|
||||
});
|
||||
let components: azdata.Component[] = [];
|
||||
|
||||
this._options.forEach(option => {
|
||||
switch (option) {
|
||||
case ModelSourceType.Local:
|
||||
if (this._localModel) {
|
||||
components.push(this._localModel);
|
||||
}
|
||||
break;
|
||||
case ModelSourceType.Azure:
|
||||
if (this._amlModel) {
|
||||
components.push(this._amlModel);
|
||||
}
|
||||
break;
|
||||
case ModelSourceType.RegisteredModels:
|
||||
if (this._registeredModels) {
|
||||
components.push(this._registeredModels);
|
||||
}
|
||||
break;
|
||||
}
|
||||
});
|
||||
this._sourceType = this._options[0];
|
||||
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
|
||||
|
||||
this._flexContainer = modelBuilder.flexContainer()
|
||||
.withLayout({
|
||||
flexFlow: 'row',
|
||||
justifyContent: 'space-between'
|
||||
}).withItems(components).component();
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: '',
|
||||
component: this._flexContainer
|
||||
}]).component();
|
||||
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flexContainer) {
|
||||
formBuilder.addFormItem({ title: '', component: this._flexContainer });
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._flexContainer) {
|
||||
formBuilder.removeFormItem({ title: '', component: this._flexContainer });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ModelSourceType {
|
||||
return this._sourceType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
}
|
||||
328
extensions/machine-learning/src/views/models/modelViewBase.ts
Normal file
328
extensions/machine-learning/src/views/models/modelViewBase.ts
Normal file
@@ -0,0 +1,328 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
|
||||
import { azureResource } from '../../typings/azure-resource';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import { ViewBase } from '../viewBase';
|
||||
import { ImportedModel, WorkspaceModel, ImportedModelDetails, ModelParameters } from '../../modelManagement/interfaces';
|
||||
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
|
||||
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
|
||||
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
|
||||
|
||||
|
||||
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
|
||||
}
|
||||
|
||||
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
|
||||
details?: ImportedModelDetails
|
||||
}
|
||||
|
||||
export interface PredictModelEventArgs extends PredictParameters {
|
||||
model?: ImportedModel;
|
||||
filePath?: string;
|
||||
}
|
||||
|
||||
|
||||
export enum ModelSourceType {
|
||||
Local,
|
||||
Azure,
|
||||
RegisteredModels
|
||||
}
|
||||
|
||||
export interface ModelViewData {
|
||||
modelFile?: string;
|
||||
modelData: AzureModelResource | string | ImportedModel;
|
||||
modelDetails?: ImportedModelDetails;
|
||||
targetImportTable?: DatabaseTable;
|
||||
}
|
||||
|
||||
// Event names
|
||||
//
|
||||
export const ListModelsEventName = 'listModels';
|
||||
export const ListAzureModelsEventName = 'listAzureModels';
|
||||
export const ListAccountsEventName = 'listAccounts';
|
||||
export const ListDatabaseNamesEventName = 'listDatabaseNames';
|
||||
export const ListTableNamesEventName = 'listTableNames';
|
||||
export const ListColumnNamesEventName = 'listColumnNames';
|
||||
export const ListSubscriptionsEventName = 'listSubscriptions';
|
||||
export const ListGroupsEventName = 'listGroups';
|
||||
export const ListWorkspacesEventName = 'listWorkspaces';
|
||||
export const RegisterLocalModelEventName = 'registerLocalModel';
|
||||
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
|
||||
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
|
||||
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
|
||||
export const PredictModelEventName = 'predictModel';
|
||||
export const RegisterModelEventName = 'registerModel';
|
||||
export const EditModelEventName = 'editModel';
|
||||
export const UpdateModelEventName = 'updateModel';
|
||||
export const DeleteModelEventName = 'deleteModel';
|
||||
export const SourceModelSelectedEventName = 'sourceModelSelected';
|
||||
export const LoadModelParametersEventName = 'loadModelParameters';
|
||||
export const StoreImportTableEventName = 'storeImportTable';
|
||||
export const VerifyImportTableEventName = 'verifyImportTable';
|
||||
export const SignInToAzureEventName = 'signInToAzure';
|
||||
|
||||
/**
|
||||
* Base class for all model management views
|
||||
*/
|
||||
export abstract class ModelViewBase extends ViewBase {
|
||||
|
||||
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
|
||||
private _modelsViewData: ModelViewData[] = [];
|
||||
private _importTable: DatabaseTable | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
|
||||
super(apiWrapper, root, parent);
|
||||
}
|
||||
|
||||
protected getEventNames(): string[] {
|
||||
return super.getEventNames().concat([ListModelsEventName,
|
||||
ListAzureModelsEventName,
|
||||
ListAccountsEventName,
|
||||
ListSubscriptionsEventName,
|
||||
ListGroupsEventName,
|
||||
ListWorkspacesEventName,
|
||||
RegisterLocalModelEventName,
|
||||
RegisterAzureModelEventName,
|
||||
RegisterModelEventName,
|
||||
SourceModelSelectedEventName,
|
||||
ListDatabaseNamesEventName,
|
||||
ListTableNamesEventName,
|
||||
ListColumnNamesEventName,
|
||||
PredictModelEventName,
|
||||
DownloadAzureModelEventName,
|
||||
DownloadRegisteredModelEventName,
|
||||
LoadModelParametersEventName,
|
||||
StoreImportTableEventName,
|
||||
VerifyImportTableEventName,
|
||||
EditModelEventName,
|
||||
UpdateModelEventName,
|
||||
DeleteModelEventName,
|
||||
SignInToAzureEventName]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parent view
|
||||
*/
|
||||
public get parent(): ModelViewBase | undefined {
|
||||
return this._parent ? <ModelViewBase>this._parent : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* list azure models
|
||||
*/
|
||||
public async listAzureModels(workspaceResource: AzureWorkspaceResource): Promise<WorkspaceModel[]> {
|
||||
const args: AzureResourceEventArgs = workspaceResource;
|
||||
return await this.sendDataRequest(ListAzureModelsEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* list registered models
|
||||
*/
|
||||
public async listModels(table: DatabaseTable): Promise<ImportedModel[]> {
|
||||
return await this.sendDataRequest(ListModelsEventName, table);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists azure accounts
|
||||
*/
|
||||
public async listAzureAccounts(): Promise<azdata.Account[]> {
|
||||
return await this.sendDataRequest(ListAccountsEventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists database names
|
||||
*/
|
||||
public async listDatabaseNames(): Promise<string[]> {
|
||||
return await this.sendDataRequest(ListDatabaseNamesEventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists table names
|
||||
*/
|
||||
public async listTableNames(dbName: string): Promise<DatabaseTable[]> {
|
||||
return await this.sendDataRequest(ListTableNamesEventName, dbName);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists column names
|
||||
*/
|
||||
public async listColumnNames(table: DatabaseTable): Promise<TableColumn[]> {
|
||||
return await this.sendDataRequest(ListColumnNamesEventName, table);
|
||||
}
|
||||
|
||||
/**
|
||||
* lists azure subscriptions
|
||||
* @param account azure account
|
||||
*/
|
||||
public async listAzureSubscriptions(account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[]> {
|
||||
const args: AzureResourceEventArgs = {
|
||||
account: account
|
||||
};
|
||||
return await this.sendDataRequest(ListSubscriptionsEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* registers local model
|
||||
* @param localFilePath local file path
|
||||
*/
|
||||
public async importLocalModel(models: ModelViewData[]): Promise<void> {
|
||||
return await this.sendDataRequest(RegisterLocalModelEventName, models);
|
||||
}
|
||||
|
||||
/**
|
||||
* downloads registered model
|
||||
* @param model model to download
|
||||
*/
|
||||
public async downloadRegisteredModel(model: ImportedModel | undefined): Promise<string> {
|
||||
return await this.sendDataRequest(DownloadRegisteredModelEventName, model);
|
||||
}
|
||||
|
||||
/**
|
||||
* download azure model
|
||||
* @param args azure resource
|
||||
*/
|
||||
public async downloadAzureModel(resource: AzureModelResource | undefined): Promise<string> {
|
||||
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads model parameters
|
||||
*/
|
||||
public async loadModelParameters(): Promise<ModelParameters | undefined> {
|
||||
return await this.sendDataRequest(LoadModelParametersEventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* registers azure model
|
||||
* @param args azure resource
|
||||
*/
|
||||
public async importAzureModel(models: ModelViewData[]): Promise<void> {
|
||||
return await this.sendDataRequest(RegisterAzureModelEventName, models);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stores the name of the table as recent config table for importing models
|
||||
*/
|
||||
public async storeImportConfigTable(): Promise<void> {
|
||||
await this.sendRequest(StoreImportTableEventName, this.importTable);
|
||||
}
|
||||
|
||||
/**
|
||||
* Verifies if table is valid to import models to
|
||||
*/
|
||||
public async verifyImportConfigTable(table: DatabaseTable): Promise<boolean> {
|
||||
return await this.sendDataRequest(VerifyImportTableEventName, table);
|
||||
}
|
||||
|
||||
/**
|
||||
* registers azure model
|
||||
* @param args azure resource
|
||||
*/
|
||||
public async generatePredictScript(model: ImportedModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> {
|
||||
const args: PredictModelEventArgs = Object.assign({}, params, {
|
||||
model: model,
|
||||
filePath: filePath,
|
||||
loadFromRegisteredModel: !filePath
|
||||
});
|
||||
return await this.sendDataRequest(PredictModelEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* list resource groups
|
||||
* @param account azure account
|
||||
* @param subscription azure subscription
|
||||
*/
|
||||
public async listAzureGroups(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[]> {
|
||||
const args: AzureResourceEventArgs = {
|
||||
account: account,
|
||||
subscription: subscription
|
||||
};
|
||||
return await this.sendDataRequest(ListGroupsEventName, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets model source type
|
||||
*/
|
||||
public set modelSourceType(value: ModelSourceType) {
|
||||
if (this.parent) {
|
||||
this.parent.modelSourceType = value;
|
||||
} else {
|
||||
this._modelSourceType = value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns model source type
|
||||
*/
|
||||
public get modelSourceType(): ModelSourceType {
|
||||
if (this.parent) {
|
||||
return this.parent.modelSourceType;
|
||||
} else {
|
||||
return this._modelSourceType;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets model data
|
||||
*/
|
||||
public set modelsViewData(value: ModelViewData[]) {
|
||||
if (this.parent) {
|
||||
this.parent.modelsViewData = value;
|
||||
} else {
|
||||
this._modelsViewData = value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns model data
|
||||
*/
|
||||
public get modelsViewData(): ModelViewData[] {
|
||||
if (this.parent) {
|
||||
return this.parent.modelsViewData;
|
||||
} else {
|
||||
return this._modelsViewData;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets import table
|
||||
*/
|
||||
public set importTable(value: DatabaseTable | undefined) {
|
||||
if (this.parent) {
|
||||
this.parent.importTable = value;
|
||||
} else {
|
||||
this._importTable = value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns import table
|
||||
*/
|
||||
public get importTable(): DatabaseTable | undefined {
|
||||
if (this.parent) {
|
||||
return this.parent.importTable;
|
||||
} else {
|
||||
return this._importTable;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* lists azure workspaces
|
||||
* @param account azure account
|
||||
* @param subscription azure subscription
|
||||
* @param group azure resource group
|
||||
*/
|
||||
public async listWorkspaces(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[]> {
|
||||
const args: AzureResourceEventArgs = {
|
||||
account: account,
|
||||
subscription: subscription,
|
||||
group: group
|
||||
};
|
||||
return await this.sendDataRequest(ListWorkspacesEventName, args);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase, ModelViewData } from './modelViewBase';
|
||||
import { ApiWrapper } from '../../common/apiWrapper';
|
||||
import * as constants from '../../common/constants';
|
||||
import { IDataComponent } from '../interfaces';
|
||||
|
||||
/**
|
||||
* View to pick local models file
|
||||
*/
|
||||
export class ModelsDetailsTableComponent extends ModelViewBase implements IDataComponent<ModelViewData[]> {
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
|
||||
/**
|
||||
* Creates new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register the components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: [
|
||||
{ // Name
|
||||
displayName: constants.modelFileName,
|
||||
ariaLabel: constants.modelFileName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Name
|
||||
displayName: constants.modelName,
|
||||
ariaLabel: constants.modelName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 150,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Created
|
||||
displayName: constants.modelDescription,
|
||||
ariaLabel: constants.modelDescription,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 100,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
],
|
||||
data: [],
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
|
||||
return this._table;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._table) {
|
||||
formBuilder.addFormItems([{
|
||||
title: '',
|
||||
component: this._table
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._table) {
|
||||
formBuilder.removeFormItem({
|
||||
title: '',
|
||||
component: this._table
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load data in the component
|
||||
* @param workspaceResource Azure workspace
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
|
||||
const models = this.modelsViewData;
|
||||
if (this._table && models) {
|
||||
|
||||
let tableData: any[][] = [];
|
||||
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
|
||||
this._table.data = tableData;
|
||||
}
|
||||
}
|
||||
|
||||
private createTableRow(model: ModelViewData | undefined): any[] {
|
||||
if (this._modelBuilder && model && model.modelDetails) {
|
||||
const nameComponent = this._modelBuilder.inputBox().withProperties({
|
||||
value: model.modelDetails.modelName,
|
||||
width: this.componentMaxLength - 100,
|
||||
required: true
|
||||
}).component();
|
||||
const descriptionComponent = this._modelBuilder.inputBox().withProperties({
|
||||
value: model.modelDetails.description,
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
descriptionComponent.onTextChanged(() => {
|
||||
if (model.modelDetails) {
|
||||
model.modelDetails.description = descriptionComponent.value;
|
||||
}
|
||||
});
|
||||
nameComponent.onTextChanged(() => {
|
||||
if (model.modelDetails) {
|
||||
model.modelDetails.modelName = nameComponent.value || '';
|
||||
}
|
||||
});
|
||||
let deleteButton = this._modelBuilder.button().withProperties({
|
||||
label: '',
|
||||
title: constants.deleteTitle,
|
||||
width: 15,
|
||||
height: 15,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
|
||||
light: this.asAbsolutePath('images/light/delete.svg')
|
||||
},
|
||||
}).component();
|
||||
deleteButton.onDidClick(async () => {
|
||||
this.modelsViewData = this.modelsViewData.filter(x => x !== model);
|
||||
await this.refresh();
|
||||
});
|
||||
return [model.modelDetails.fileName, nameComponent, descriptionComponent, deleteButton];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): ModelViewData[] {
|
||||
return this.modelsViewData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._table;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IPageView, IDataComponent } from '../../interfaces';
|
||||
import { InputColumnsComponent } from './inputColumnsComponent';
|
||||
import { OutputColumnsComponent } from './outputColumnsComponent';
|
||||
import { PredictParameters } from '../../../prediction/interfaces';
|
||||
|
||||
/**
|
||||
* View to pick model source
|
||||
*/
|
||||
export class ColumnsSelectionPage extends ModelViewBase implements IPageView, IDataComponent<PredictParameters> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _formBuilder: azdata.FormBuilder | undefined;
|
||||
public inputColumnsComponent: InputColumnsComponent | undefined;
|
||||
public outputColumnsComponent: OutputColumnsComponent | undefined;
|
||||
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param modelBuilder Register components
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._formBuilder = modelBuilder.formContainer();
|
||||
this.inputColumnsComponent = new InputColumnsComponent(this._apiWrapper, this);
|
||||
this.inputColumnsComponent.registerComponent(modelBuilder);
|
||||
this.inputColumnsComponent.addComponents(this._formBuilder);
|
||||
|
||||
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
|
||||
this.outputColumnsComponent.registerComponent(modelBuilder);
|
||||
this.outputColumnsComponent.addComponents(this._formBuilder);
|
||||
|
||||
this._form = this._formBuilder.component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictParameters | undefined {
|
||||
return this.inputColumnsComponent?.data && this.outputColumnsComponent?.data ?
|
||||
Object.assign({}, this.inputColumnsComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
|
||||
undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
if (this._formBuilder) {
|
||||
if (this.inputColumnsComponent) {
|
||||
await this.inputColumnsComponent.refresh();
|
||||
}
|
||||
if (this.outputColumnsComponent) {
|
||||
await this.outputColumnsComponent.refresh();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async onEnter(): Promise<void> {
|
||||
await this.inputColumnsComponent?.onLoading();
|
||||
await this.outputColumnsComponent?.onLoading();
|
||||
try {
|
||||
const modelParameters = await this.loadModelParameters();
|
||||
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
|
||||
this.inputColumnsComponent.modelParameters = modelParameters;
|
||||
this.outputColumnsComponent.modelParameters = modelParameters;
|
||||
await this.inputColumnsComponent.refresh();
|
||||
await this.outputColumnsComponent.refresh();
|
||||
}
|
||||
} catch (error) {
|
||||
this.showErrorMessage(constants.loadModelParameterFailedError, error);
|
||||
}
|
||||
await this.inputColumnsComponent?.onLoaded();
|
||||
await this.outputColumnsComponent?.onLoaded();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns page title
|
||||
*/
|
||||
public get title(): string {
|
||||
return constants.columnSelectionPageTitle;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { PredictColumn, DatabaseTable, TableColumn } from '../../../prediction/interfaces';
|
||||
import { ModelParameter, ModelParameters } from '../../../modelManagement/interfaces';
|
||||
|
||||
/**
|
||||
* View to render azure models in a table
|
||||
*/
|
||||
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
|
||||
|
||||
private _table: azdata.DeclarativeTableComponent | undefined;
|
||||
private _parameters: PredictColumn[] = [];
|
||||
private _loader: azdata.LoadingComponent;
|
||||
private _dataTypes: string[] = [
|
||||
'bigint',
|
||||
'int',
|
||||
'smallint',
|
||||
'real',
|
||||
'float',
|
||||
'varchar(MAX)',
|
||||
'bit'
|
||||
];
|
||||
|
||||
|
||||
/**
|
||||
* Creates a view to render azure models in a table
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _forInput: boolean = true) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
this._loader = this.registerComponent(this._modelBuilder);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.LoadingComponent {
|
||||
let columnHeader: azdata.DeclarativeTableColumn[];
|
||||
if (this._forInput) {
|
||||
columnHeader = [
|
||||
{ // Action
|
||||
displayName: constants.columnName,
|
||||
ariaLabel: constants.columnName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Name
|
||||
displayName: '',
|
||||
ariaLabel: '',
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Name
|
||||
displayName: constants.inputName,
|
||||
ariaLabel: constants.inputName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 120,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
];
|
||||
} else {
|
||||
columnHeader = [
|
||||
{ // Name
|
||||
displayName: constants.outputName,
|
||||
ariaLabel: constants.outputName,
|
||||
valueType: azdata.DeclarativeDataType.string,
|
||||
isReadOnly: true,
|
||||
width: 200,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: constants.displayName,
|
||||
ariaLabel: constants.displayName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
},
|
||||
{ // Action
|
||||
displayName: constants.dataTypeName,
|
||||
ariaLabel: constants.dataTypeName,
|
||||
valueType: azdata.DeclarativeDataType.component,
|
||||
isReadOnly: true,
|
||||
width: 50,
|
||||
headerCssStyles: {
|
||||
...constants.cssStyles.tableHeader
|
||||
},
|
||||
rowCssStyles: {
|
||||
...constants.cssStyles.tableRow
|
||||
},
|
||||
}
|
||||
];
|
||||
}
|
||||
this._table = modelBuilder.declarativeTable()
|
||||
|
||||
.withProperties<azdata.DeclarativeTableProperties>(
|
||||
{
|
||||
columns: columnHeader,
|
||||
data: [],
|
||||
ariaLabel: constants.mlsConfigTitle
|
||||
})
|
||||
.component();
|
||||
this._loader = modelBuilder.loadingComponent()
|
||||
.withItem(this._table)
|
||||
.withProperties({
|
||||
loading: true
|
||||
}).component();
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
public async onLoading(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: true });
|
||||
}
|
||||
}
|
||||
|
||||
public async onLoaded(): Promise<void> {
|
||||
if (this._loader) {
|
||||
await this._loader.updateProperties({ loading: false });
|
||||
}
|
||||
}
|
||||
|
||||
public get component(): azdata.Component {
|
||||
return this._loader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load data in the component
|
||||
* @param workspaceResource Azure workspace
|
||||
*/
|
||||
public async loadInputs(modelParameters: ModelParameters | undefined, table: DatabaseTable): Promise<void> {
|
||||
await this.onLoading();
|
||||
this._parameters = [];
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (this._table) {
|
||||
if (this._forInput) {
|
||||
const columns = await this.listColumnNames(table);
|
||||
if (modelParameters?.inputs && columns) {
|
||||
tableData = tableData.concat(modelParameters.inputs.map(input => this.createInputTableRow(input, columns)));
|
||||
}
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
await this.onLoaded();
|
||||
}
|
||||
|
||||
public async loadOutputs(modelParameters: ModelParameters | undefined): Promise<void> {
|
||||
this.onLoading();
|
||||
this._parameters = [];
|
||||
let tableData: any[][] = [];
|
||||
|
||||
if (this._table) {
|
||||
if (!this._forInput) {
|
||||
if (modelParameters?.outputs && this._dataTypes) {
|
||||
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
|
||||
}
|
||||
}
|
||||
|
||||
this._table.data = tableData;
|
||||
}
|
||||
this.onLoaded();
|
||||
}
|
||||
|
||||
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
|
||||
if (this._modelBuilder) {
|
||||
|
||||
let nameInput = this._modelBuilder.dropDown().withProperties({
|
||||
values: dataTypes,
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
const name = modelParameter.name;
|
||||
const dataType = dataTypes.find(x => x === modelParameter.type);
|
||||
if (dataType) {
|
||||
nameInput.value = dataType;
|
||||
}
|
||||
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
|
||||
|
||||
nameInput.onValueChanged(() => {
|
||||
const value = <string>nameInput.value;
|
||||
if (value !== modelParameter.type) {
|
||||
let selectedRow = this._parameters.find(x => x.paramName === name);
|
||||
if (selectedRow) {
|
||||
selectedRow.dataType = value;
|
||||
}
|
||||
}
|
||||
});
|
||||
let displayNameInput = this._modelBuilder.inputBox().withProperties({
|
||||
value: name,
|
||||
width: 200
|
||||
}).component();
|
||||
displayNameInput.onTextChanged(() => {
|
||||
let selectedRow = this._parameters.find(x => x.paramName === name);
|
||||
if (selectedRow) {
|
||||
selectedRow.columnName = displayNameInput.value || name;
|
||||
}
|
||||
});
|
||||
return [`${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
private createInputTableRow(modelParameter: ModelParameter, columns: TableColumn[] | undefined): any[] {
|
||||
if (this._modelBuilder && columns) {
|
||||
const values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; });
|
||||
let nameInput = this._modelBuilder.dropDown().withProperties({
|
||||
values: values,
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
const name = modelParameter.name;
|
||||
let column = values.find(x => x.name === modelParameter.name);
|
||||
if (!column) {
|
||||
column = values[0];
|
||||
}
|
||||
nameInput.value = column;
|
||||
|
||||
this._parameters.push({ columnName: column.name, paramName: name });
|
||||
|
||||
nameInput.onValueChanged(() => {
|
||||
const selectedColumn = nameInput.value;
|
||||
const value = selectedColumn ? (<azdata.CategoryValue>selectedColumn).name : undefined;
|
||||
|
||||
let selectedRow = this._parameters.find(x => x.paramName === name);
|
||||
if (selectedRow) {
|
||||
selectedRow.columnName = value || '';
|
||||
}
|
||||
});
|
||||
const label = this._modelBuilder.inputBox().withProperties({
|
||||
value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`,
|
||||
enabled: false,
|
||||
width: this.componentMaxLength
|
||||
}).component();
|
||||
const image = this._modelBuilder.image().withProperties({
|
||||
width: 50,
|
||||
height: 50,
|
||||
iconPath: {
|
||||
dark: this.asAbsolutePath('images/arrow.svg'),
|
||||
light: this.asAbsolutePath('images/arrow.svg')
|
||||
},
|
||||
iconWidth: 20,
|
||||
iconHeight: 20,
|
||||
title: 'maps'
|
||||
}).component();
|
||||
return [nameInput, image, label];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictColumn[] | undefined {
|
||||
return this._parameters;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as azdata from 'azdata';
|
||||
import { ModelViewBase } from '../modelViewBase';
|
||||
import { ApiWrapper } from '../../../common/apiWrapper';
|
||||
import * as constants from '../../../common/constants';
|
||||
import { IDataComponent } from '../../interfaces';
|
||||
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
|
||||
import { ModelParameters } from '../../../modelManagement/interfaces';
|
||||
import { ColumnsTable } from './columnsTable';
|
||||
import { TableSelectionComponent } from '../tableSelectionComponent';
|
||||
|
||||
/**
|
||||
* View to render filters to pick an azure resource
|
||||
*/
|
||||
export class InputColumnsComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
|
||||
|
||||
private _form: azdata.FormContainer | undefined;
|
||||
private _tableSelectionComponent: TableSelectionComponent | undefined;
|
||||
private _columns: ColumnsTable | undefined;
|
||||
private _modelParameters: ModelParameters | undefined;
|
||||
|
||||
/**
|
||||
* Creates a new view
|
||||
*/
|
||||
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
|
||||
super(apiWrapper, parent.root, parent);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register components
|
||||
* @param modelBuilder model builder
|
||||
*/
|
||||
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
|
||||
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
|
||||
this._tableSelectionComponent.registerComponent(modelBuilder);
|
||||
this._tableSelectionComponent.onSelectedChanged(async () => {
|
||||
await this.onTableSelected();
|
||||
});
|
||||
|
||||
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
|
||||
|
||||
this._form = modelBuilder.formContainer().withFormItems([{
|
||||
title: constants.inputColumns,
|
||||
component: this._columns.component
|
||||
}]).component();
|
||||
return this._form;
|
||||
}
|
||||
|
||||
public addComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
|
||||
formBuilder.addFormItems([{
|
||||
title: '',
|
||||
component: this._tableSelectionComponent.component
|
||||
}, {
|
||||
title: constants.inputColumns,
|
||||
component: this._columns.component
|
||||
}]);
|
||||
}
|
||||
}
|
||||
|
||||
public removeComponents(formBuilder: azdata.FormBuilder) {
|
||||
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
|
||||
formBuilder.removeFormItem({
|
||||
title: '',
|
||||
component: this._tableSelectionComponent.component
|
||||
});
|
||||
formBuilder.removeFormItem({
|
||||
title: constants.inputColumns,
|
||||
component: this._columns.component
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the created component
|
||||
*/
|
||||
public get component(): azdata.Component | undefined {
|
||||
return this._form;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns selected data
|
||||
*/
|
||||
public get data(): PredictInputParameters | undefined {
|
||||
return Object.assign({}, this.databaseTable, {
|
||||
inputColumns: this.columnNames
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* loads data in the components
|
||||
*/
|
||||
public async loadData(): Promise<void> {
|
||||
if (this._tableSelectionComponent) {
|
||||
this._tableSelectionComponent.refresh();
|
||||
}
|
||||
}
|
||||
|
||||
public set modelParameters(value: ModelParameters) {
|
||||
this._modelParameters = value;
|
||||
}
|
||||
|
||||
public async onLoading(): Promise<void> {
|
||||
if (this._columns) {
|
||||
await this._columns.onLoading();
|
||||
}
|
||||
}
|
||||
|
||||
public async onLoaded(): Promise<void> {
|
||||
if (this._columns) {
|
||||
await this._columns.onLoaded();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* refreshes the view
|
||||
*/
|
||||
public async refresh(): Promise<void> {
|
||||
await this.loadData();
|
||||
}
|
||||
|
||||
private async onTableSelected(): Promise<void> {
|
||||
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
|
||||
}
|
||||
|
||||
private get databaseTable(): DatabaseTable {
|
||||
let selectedItem = this._tableSelectionComponent?.data;
|
||||
return {
|
||||
databaseName: selectedItem?.databaseName,
|
||||
tableName: selectedItem?.tableName,
|
||||
schema: selectedItem?.schema
|
||||
};
|
||||
}
|
||||
|
||||
private get columnNames(): PredictColumn[] | undefined {
|
||||
return this._columns?.data;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import * as utils from '../../../common/utils';
|
||||
|
||||
/**
|
||||
* Wizard to register a model
|
||||
*/
|
||||
export class ModelArtifact {
|
||||
|
||||
/**
|
||||
* Creates new model artifact
|
||||
*/
|
||||
constructor(private _filePath: string, private _deleteAtClose: boolean = true) {
|
||||
}
|
||||
|
||||
public get filePath(): string {
|
||||
return this._filePath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the artifact and disposes the resources
|
||||
*/
|
||||
public async close(): Promise<void> {
|
||||
if (this._deleteAtClose) {
|
||||
try {
|
||||
await utils.deleteFile(this._filePath);
|
||||
} catch {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user