ML - dashboard icons and links (#10153)

* ML - dashboard icons and links
This commit is contained in:
Leila Lali
2020-04-28 21:21:30 -07:00
committed by GitHub
parent 046995f2a5
commit 04af41c424
145 changed files with 387 additions and 134 deletions

View File

@@ -0,0 +1,136 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
/**
* Wrapper class to act as a facade over VSCode and Data APIs and allow us to test / mock callbacks into
* this API from our code
*/
export class ApiWrapper {
public createOutputChannel(name: string): vscode.OutputChannel {
return vscode.window.createOutputChannel(name);
}
public createTerminalWithOptions(options: vscode.TerminalOptions): vscode.Terminal {
return vscode.window.createTerminal(options);
}
public getCurrentConnection(): Thenable<azdata.connection.ConnectionProfile> {
return azdata.connection.getCurrentConnection();
}
public getCredentials(connectionId: string): Thenable<{ [name: string]: string }> {
return azdata.connection.getCredentials(connectionId);
}
public registerCommand(command: string, callback: (...args: any[]) => any, thisArg?: any): vscode.Disposable {
return vscode.commands.registerCommand(command, callback, thisArg);
}
public executeCommand<T>(command: string, ...rest: any[]): Thenable<T | undefined> {
return vscode.commands.executeCommand(command, ...rest);
}
public registerTaskHandler(taskId: string, handler: (profile: azdata.IConnectionProfile) => void): void {
azdata.tasks.registerTask(taskId, handler);
}
public getUriForConnection(connectionId: string): Thenable<string> {
return azdata.connection.getUriForConnection(connectionId);
}
public getProvider<T extends azdata.DataProvider>(providerId: string, providerType: azdata.DataProviderType): T {
return azdata.dataprotocol.getProvider<T>(providerId, providerType);
}
public showErrorMessage(message: string, ...items: string[]): Thenable<string | undefined> {
return vscode.window.showErrorMessage(message, ...items);
}
public showInfoMessage(message: string, ...items: string[]): Thenable<string | undefined> {
return vscode.window.showInformationMessage(message, ...items);
}
public showOpenDialog(options: vscode.OpenDialogOptions): Thenable<vscode.Uri[] | undefined> {
return vscode.window.showOpenDialog(options);
}
public startBackgroundOperation(operationInfo: azdata.BackgroundOperationInfo): void {
azdata.tasks.startBackgroundOperation(operationInfo);
}
public openExternal(target: vscode.Uri): Thenable<boolean> {
return vscode.env.openExternal(target);
}
public getExtension(extensionId: string): vscode.Extension<any> | undefined {
return vscode.extensions.getExtension(extensionId);
}
public getConfiguration(section?: string, resource?: vscode.Uri | null): vscode.WorkspaceConfiguration {
return vscode.workspace.getConfiguration(section, resource);
}
public createTab(title: string): azdata.window.DialogTab {
return azdata.window.createTab(title);
}
public createModelViewDialog(title: string, dialogName?: string, isWide?: boolean): azdata.window.Dialog {
return azdata.window.createModelViewDialog(title, dialogName, isWide);
}
public createWizard(title: string): azdata.window.Wizard {
return azdata.window.createWizard(title);
}
public createWizardPage(title: string): azdata.window.WizardPage {
return azdata.window.createWizardPage(title);
}
public openDialog(dialog: azdata.window.Dialog): void {
return azdata.window.openDialog(dialog);
}
public getAllAccounts(): Thenable<azdata.Account[]> {
return azdata.accounts.getAllAccounts();
}
public getSecurityToken(account: azdata.Account, resource: azdata.AzureResource): Thenable<{ [key: string]: any }> {
return azdata.accounts.getSecurityToken(account, resource);
}
public showQuickPick<T extends vscode.QuickPickItem>(items: T[] | Thenable<T[]>, options?: vscode.QuickPickOptions, token?: vscode.CancellationToken): Thenable<T | undefined> {
return vscode.window.showQuickPick(items, options, token);
}
public listDatabases(connectionId: string): Thenable<string[]> {
return azdata.connection.listDatabases(connectionId);
}
public openTextDocument(options?: { language?: string; content?: string; }): Thenable<vscode.TextDocument> {
return vscode.workspace.openTextDocument(options);
}
public connect(fileUri: string, connectionId: string): Thenable<void> {
return azdata.queryeditor.connect(fileUri, connectionId);
}
public runQuery(fileUri: string, options?: Map<string, string>, runCurrentQuery?: boolean): void {
azdata.queryeditor.runQuery(fileUri, options, runCurrentQuery);
}
public showTextDocument(uri: vscode.Uri, options?: vscode.TextDocumentShowOptions): Thenable<vscode.TextEditor> {
return vscode.window.showTextDocument(uri, options);
}
public createButton(label: string, position?: azdata.window.DialogButtonPosition): azdata.window.Button {
return azdata.window.createButton(label, position);
}
public registerWidget(widgetId: string, handler: (view: azdata.ModelView) => void): void {
azdata.ui.registerModelViewProvider(widgetId, handler);
}
}

View File

@@ -0,0 +1,235 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as nls from 'vscode-nls';
const localize = nls.loadMessageBundle();
export const winPlatform = 'win32';
export const pythonBundleVersion = '0.0.1';
export const managePackagesCommand = 'jupyter.cmd.managePackages';
export const pythonLanguageName = 'Python';
export const rLanguageName = 'R';
export const rLPackagedFolderName = 'r_packages';
export const mlEnableMlsCommand = 'mls.command.enableMls';
export const mlDisableMlsCommand = 'mls.command.disableMls';
export const extensionOutputChannel = 'Machine Learning';
export const notebookExtensionName = 'Microsoft.notebook';
export const azureSubscriptionsCommand = 'azure.accounts.getSubscriptions';
export const azureResourceGroupsCommand = 'azure.accounts.getResourceGroups';
export const signInToAzureCommand = 'azure.resource.signin';
// Tasks, commands
//
export const mlManageLanguagesCommand = 'mls.command.manageLanguages';
export const mlsPredictModelCommand = 'mls.command.predictModel';
export const mlManageModelsCommand = 'mls.command.manageModels';
export const mlImportModelCommand = 'mls.command.importModel';
export const mlManagePackagesCommand = 'mls.command.managePackages';
export const mlsDependenciesCommand = 'mls.command.dependencies';
export const notebookCommandNew = 'notebook.command.new';
// Configurations
//
export const mlsConfigKey = 'machineLearningServices';
export const pythonPathConfigKey = 'pythonPath';
export const pythonEnabledConfigKey = 'enablePython';
export const rEnabledConfigKey = 'enableR';
export const registeredModelsTableName = 'registeredModelsTableName';
export const rPathConfigKey = 'rPath';
// Localized texts
//
export const msgYes = localize('msgYes', "Yes");
export const msgNo = localize('msgNo', "No");
export const managePackageCommandError = localize('mls.managePackages.error', "Either no connection is available or the server does not have external script enabled.");
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Installing package management dependencies");
export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Installing model management dependencies");
export const noResultError = localize('mls.noResultError', "No Result returned");
export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed");
export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that.");
export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script.");
export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action.");
export function confirmInstallPythonPackages(packages: string): string {
return localize('mls.installDependencies.confirmInstallPythonPackages'
, "The following Python packages are required to install: {0}. Are you sure you want to install?", packages);
}
export function confirmDeleteModel(modelName: string): string {
return localize('models.confirmDeleteModel'
, "Are you sure you want to delete model '{0}?", modelName);
}
export const installDependenciesPackages = localize('mls.installDependencies.packages', "Installing required packages ...");
export const installDependenciesPackagesAlreadyInstalled = localize('mls.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
export function installDependenciesGetPackagesError(err: string): string { return localize('mls.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
export const noConnectionError = localize('mls.packageManager.NoConnection', "No connection selected");
export const notebookExtensionNotLoaded = localize('mls.notebookExtensionNotLoaded', "Notebook extension is not loaded");
export const mssqlExtensionNotLoaded = localize('mls.mssqlExtensionNotLoaded', "MSSQL extension is not loaded");
export const mlsEnabledMessage = localize('mls.enabledMessage', "Machine Learning Services Enabled");
export const mlsConfigUpdateFailed = localize('mls.configUpdateFailed', "Failed to modify Machine Learning Services configurations");
export const mlsEnableButtonTitle = localize('mls.enableButtonTitle', "Enable");
export const mlsDisableButtonTitle = localize('mls.disableButtonTitle', "Disable");
export const mlsConfigTitle = localize('mls.configTitle', "Config");
export const mlsConfigStatus = localize('mls.configStatus', "Enabled");
export const mlsConfigAction = localize('mls.configAction', "Action");
export const mlsExternalExecuteScriptTitle = localize('mls.externalExecuteScriptTitle', "External Execute Script");
export const mlsPythonLanguageTitle = localize('mls.pythonLanguageTitle', "Python");
export const mlsRLanguageTitle = localize('mls.rLanguageTitle', "R");
export const downloadError = localize('mls.downloadError', "Error while downloading");
export function invalidModelIdError(modelUrl: string | undefined): string { return localize('mls.invalidModelIdError', "Invalid model id. model url: {0}", modelUrl || ''); }
export function noArtifactError(modelUrl: string | undefined): string { return localize('mls.noArtifactError', "Model doesn't have any artifact. model url: {0}", modelUrl || ''); }
export const downloadingProgress = localize('mls.downloadingProgress', "Downloading");
export const pythonConfigError = localize('mls.pythonConfigError', "Python executable is not configured");
export const rConfigError = localize('mls.rConfigError', "R executable is not configured");
export const installingDependencies = localize('mls.installingDependencies', "Installing dependencies ...");
export const resourceNotFoundError = localize('mls.resourceNotFound', "Could not find the specified resource");
export const latestVersion = localize('mls.latestVersion', "Latest");
export const localhost = 'localhost';
export function httpGetRequestError(code: number, message: string): string {
return localize('mls.httpGetRequestError', "Package info request failed with error: {0} {1}",
code,
message);
}
export function getErrorMessage(error: Error): string { return localize('azure.resource.error', "Error: {0}", error?.message || error?.toString()); }
export const notSupportedEventArg = localize('notSupportedEventArg', "Not supported event args");
export const extLangInstallTabTitle = localize('extLang.installTabTitle', "Installed");
export const extLangLanguageCreatedDate = localize('extLang.languageCreatedDate', "Installed");
export const extLangLanguagePlatform = localize('extLang.languagePlatform', "Platform");
export const deleteTitle = localize('extLang.delete', "Delete");
export const extLangInstallButtonText = localize('extLang.installButtonText', "Install");
export const extLangCancelButtonText = localize('extLang.CancelButtonText', "Cancel");
export const extLangDoneButtonText = localize('extLang.DoneButtonText', "Close");
export const extLangOkButtonText = localize('extLang.OkButtonText', "OK");
export const extLangSaveButtonText = localize('extLang.SaveButtonText', "Save");
export const extLangLanguageName = localize('extLang.languageName', "Name");
export const extLangNewLanguageTabTitle = localize('extLang.newLanguageTabTitle', "Add new");
export const extLangFileBrowserTabTitle = localize('extLang.fileBrowserTabTitle', "File Browser");
export const extLangDialogTitle = localize('extLang.DialogTitle', "Languages");
export const extLangTarget = localize('extLang.Target', "Target");
export const extLangLocal = localize('extLang.Local', "localhost");
export const extLangExtensionFilePath = localize('extLang.extensionFilePath', "Language extension path");
export const extLangExtensionFileLocation = localize('extLang.extensionFileLocation', "Language extension location");
export const extLangExtensionFileName = localize('extLang.extensionFileName', "Extension file Name");
export const extLangEnvVariables = localize('extLang.envVariables', "Environment variables");
export const extLangParameters = localize('extLang.parameters', "Parameters");
export const extLangSelectedPath = localize('extLang.selectedPath', "Selected Path");
export const extLangInstallFailedError = localize('extLang.installFailedError', "Failed to install language");
export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language");
export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model");
export const databaseName = localize('databaseName', "Database name");
export const tableName = localize('tableName', "Table name");
export const modelName = localize('models.name', "Name");
export const modelFileName = localize('models.fileName', "File");
export const modelDescription = localize('models.description', "Description");
export const modelCreated = localize('models.created', "Date created");
export const modelDeployed = localize('models.deployed', "Date deployed");
export const modelFramework = localize('models.framework', "Framework");
export const modelFrameworkVersion = localize('models.frameworkVersion', "Framework version");
export const modelVersion = localize('models.version', "Version");
export const browseModels = localize('models.browseButton', "...");
export const azureAccount = localize('models.azureAccount', "Azure account");
export const azureSignIn = localize('models.azureSignIn', "Sign in to Azure");
export const columnDatabase = localize('predict.columnDatabase', "Target database");
export const columnTable = localize('predict.columnTable', "Target table");
export const inputColumns = localize('predict.inputColumns', "Model input mapping");
export const outputColumns = localize('predict.outputColumns', "Model output");
export const columnName = localize('predict.columnName', "Target columns");
export const dataTypeName = localize('predict.dataTypeName', "Type");
export const displayName = localize('predict.displayName', "Display name");
export const inputName = localize('predict.inputName', "Required model input features");
export const outputName = localize('predict.outputName', "Name");
export const azureSubscription = localize('models.azureSubscription', "Azure subscription");
export const azureGroup = localize('models.azureGroup', "Azure resource group");
export const azureModelWorkspace = localize('models.azureModelWorkspace', "Azure ML workspace");
export const azureModelFilter = localize('models.azureModelFilter', "Filter");
export const azureModels = localize('models.azureModels', "Models");
export const azureModelsTitle = localize('models.azureModelsTitle', "Azure models");
export const localModelsTitle = localize('models.localModelsTitle', "Local models");
export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location");
export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?");
export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Where do you want import models to?");
export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input");
export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details");
export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file");
export const currentModelsTitle = localize('models.currentModelsTitle', "Models");
export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy");
export const predictModel = localize('models.predictModel', "Predict");
export const registerModelTitle = localize('models.RegisterWizard', "Import models");
export const importModelTitle = localize('models.importModelTitle', "Import models");
export const editModelTitle = localize('models.editModelTitle', "Edit model");
export const importModelDesc = localize('models.importModelDesc', "Build, import and expose a machine learning model");
export const makePredictionTitle = localize('models.makePredictionTitle', "Make predictions");
export const makePredictionDesc = localize('models.makePredictionDesc', "Generates a predicted value or scores using a managed model");
export const createNotebookTitle = localize('models.createNotebookTitle', "Create notebook");
export const createNotebookDesc = localize('models.createNotebookDesc', "Run experiments and create models");
export const modelRegisteredSuccessfully = localize('models.modelRegisteredSuccessfully', "Model registered successfully");
export const modelUpdatedSuccessfully = localize('models.modelUpdatedSuccessfully', "Model updated successfully");
export const modelFailedToRegister = localize('models.modelFailedToRegistered', "Model failed to register");
export const localModelSource = localize('models.localModelSource', "File upload");
export const localModelPageTitle = localize('models.localModelPageTitle', "Upload model file");
export const azureModelSource = localize('models.azureModelSource', "Azure Machine Learning");
export const azureModelPageTitle = localize('models.azureModelPageTitle', "Import from Azure Machine Learning");
export const importedModelsPageTitle = localize('models.importedModelsPageTitle', "Select imported model");
export const registeredModelsSource = localize('models.registeredModelsSource', "Imported models");
export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskName', "Downloading Model from Azure");
export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource");
export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register");
export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict");
export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model");
export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table");
export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required.");
export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model");
export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); }
export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
export function invalidImportTableSchemaError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableSchemaError', "Table schema is not supported for model import. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); }
export const loadModelParameterFailedError = localize('models.loadModelParameterFailedError', "Failed to load model parameters'");
export const unsupportedModelParameterType = localize('models.unsupportedModelParameterType', "unsupported");
export const dashboardTitle = localize('dashboardTitle', "Machine Learning");
export const dashboardDesc = localize('dashboardDesc', "Machine Learning for SQL Databases");
export const dashboardLinksTitle = localize('dashboardLinksTitle', "Useful links");
export const dashboardVideoLinksTitle = localize('dashboardVideoLinksTitle', "Video tutorials");
export const showMoreTitle = localize('showMoreTitle', "Show more");
export const showLessTitle = localize('showLessTitle', "Show less");
export const learnMoreTitle = localize('learnMoreTitle', "Learn more");
export const sqlMlDocTitle = localize('sqlMlDocTitle', "SQL machine learning documentation");
export const sqlMlDocDesc = localize('sqlMlDocDesc', "Learn how to use machine learning in SQL Server and SQL on Azure, to run Python and R scripts on relational data.");
export const sqlMlsDocTitle = localize('sqlMlsDocTitle', "SQL Server Machine Learning Services (Python and R)");
export const sqlMlsDocDesc = localize('sqlMlsDocDesc', "Get started with Machine Learning Services on SQL Server and how to install it on Windows and Linux.");
export const sqlMlsAzureDocTitle = localize('sqlMlsAzureDocTitle', "Machine Learning Services in Azure SQL Managed Instance (preview)");
export const sqlMlsAzureDocDesc = localize('sqlMlsAzureDocDesc', "Get started with Machine Learning Services in Azure SQL Managed Instances.");
export const mlsInstallOdbcDocTitle = localize('mlsInstallObdcDocTitle', "Install the Microsoft ODBC driver for SQL Server");
export const mlsInstallOdbcDocDesc = localize('mlsInstallOdbcDocDesc', "This document explains how to install the Microsoft ODBC Driver for SQL Server.");
// Links
//
export const mlsDocuments = 'https://docs.microsoft.com/sql/advanced-analytics/?view=sql-server-ver15';
export const odbcDriverWindowsDocuments = 'https://docs.microsoft.com/sql/connect/odbc/windows/microsoft-odbc-driver-for-sql-server-on-windows?view=sql-server-ver15';
export const odbcDriverLinuxDocuments = 'https://docs.microsoft.com/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-for-sql-server?view=sql-server-ver15';
export const mlDocLink = 'https://docs.microsoft.com/sql/machine-learning/';
export const mlsDocLink = 'https://docs.microsoft.com/sql/machine-learning/what-is-sql-server-machine-learning';
export const mlsAzureDocLink = 'https://docs.microsoft.com/azure/sql-database/sql-database-managed-instance-machine-learning-services-overview';
export const installMlsWindowsDocs = 'https://docs.microsoft.com/sql/advanced-analytics/install/sql-machine-learning-services-windows-install?view=sql-server-ver15';
// CSS Styles
//
export namespace cssStyles {
export const title = { 'font-size': '14px', 'font-weight': '600' };
export const tableHeader = { 'text-align': 'left', 'font-weight': 'bold', 'text-transform': 'uppercase', 'font-size': '10px', 'user-select': 'text', 'border': 'none' };
export const tableRow = { 'border-top': 'solid 1px #ccc', 'border-bottom': 'solid 1px #ccc', 'border-left': 'none', 'border-right': 'none' };
export const hyperlink = { 'user-select': 'text', 'color': '#0078d4', 'text-decoration': 'underline', 'cursor': 'pointer' };
export const text = { 'margin-block-start': '0px', 'margin-block-end': '0px' };
export const overflowEllipsisText = { ...text, 'overflow': 'hidden', 'text-overflow': 'ellipsis' };
export const nonSelectableText = { ...cssStyles.text, 'user-select': 'none' };
export const tabHeaderText = { 'margin-block-start': '2px', 'margin-block-end': '0px', 'user-select': 'none' };
export const selectedResourceHeaderTab = { 'font-weight': 'bold', 'color': '' };
export const unselectedResourceHeaderTab = { 'font-weight': '', 'color': '#0078d4' };
export const selectedTabDiv = { 'border-bottom': '2px solid #000' };
export const unselectedTabDiv = { 'border-bottom': '1px solid #ccc' };
export const lastUpdatedText = { ...text, 'color': '#595959' };
export const errorText = { ...text, 'color': 'red' };
}

View File

@@ -0,0 +1,45 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
export class EventEmitterCollection extends vscode.Disposable {
private _events: Map<string, vscode.EventEmitter<any>[]> = new Map<string, vscode.EventEmitter<any>[]>();
/**
*
*/
constructor() {
super(() => this.dispose());
}
public on(evt: string, listener: (e: any) => any, thisArgs?: any) {
if (!this._events.has(evt)) {
this._events.set(evt, []);
}
let eventEmitter = new vscode.EventEmitter<any>();
eventEmitter.event(listener, thisArgs);
this._events.get(evt)?.push(eventEmitter);
return this;
}
public fire(evt: string, arg?: any) {
if (!this._events.has(evt)) {
this._events.set(evt, []);
}
this._events.get(evt)?.forEach(eventEmitter => {
eventEmitter.fire(arg);
});
}
public dispose(): any {
this._events.forEach(events => {
events.forEach(event => {
event.dispose();
});
});
}
}

View File

@@ -0,0 +1,80 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as fs from 'fs';
import * as request from 'request';
import * as constants from './constants';
const DownloadTimeout = 20000;
const GetTimeout = 10000;
export class HttpClient {
public async fetch(url: string): Promise<any> {
return new Promise<any>((resolve, reject) => {
request.get(url, { timeout: GetTimeout }, (error, response, body) => {
if (error) {
return reject(error);
}
if (response.statusCode === 404) {
return reject(constants.resourceNotFoundError);
}
if (response.statusCode !== 200) {
return reject(
constants.httpGetRequestError(
response.statusCode,
response.statusMessage));
}
resolve(body);
});
});
}
public download(downloadUrl: string, targetPath: string, outputChannel: vscode.OutputChannel): Promise<void> {
return new Promise((resolve, reject) => {
let totalMegaBytes: number | undefined = undefined;
let receivedBytes = 0;
let printThreshold = 0.1;
let downloadRequest = request.get(downloadUrl, { timeout: DownloadTimeout })
.on('error', downloadError => {
outputChannel.appendLine(constants.downloadError);
reject(downloadError);
})
.on('response', (response) => {
if (response.statusCode !== 200) {
outputChannel.appendLine(constants.downloadError);
return reject(response.statusMessage);
}
let contentLength = response.headers['content-length'];
let totalBytes = parseInt(contentLength || '0');
totalMegaBytes = totalBytes / (1024 * 1024);
outputChannel.appendLine(`'Downloading' (0 / ${totalMegaBytes.toFixed(2)} MB)`);
})
.on('data', (data) => {
receivedBytes += data.length;
if (totalMegaBytes) {
let receivedMegaBytes = receivedBytes / (1024 * 1024);
let percentage = receivedMegaBytes / totalMegaBytes;
if (percentage >= printThreshold) {
outputChannel.appendLine(`${constants.downloadingProgress} (${receivedMegaBytes.toFixed(2)} / ${totalMegaBytes.toFixed(2)} MB)`);
printThreshold += 0.1;
}
}
});
downloadRequest.pipe(fs.createWriteStream(targetPath))
.on('close', async () => {
resolve();
})
.on('error', (downloadError) => {
reject(downloadError);
downloadRequest.abort();
});
});
}
}

View File

@@ -0,0 +1,91 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as childProcess from 'child_process';
const ExecScriptsTimeoutInSeconds = 600000;
export class ProcessService {
public timeout = ExecScriptsTimeoutInSeconds;
public async execScripts(exeFilePath: string, scripts: string[], args?: string[], outputChannel?: vscode.OutputChannel): Promise<string> {
return new Promise<string>((resolve, reject) => {
const scriptExecution = childProcess.spawn(exeFilePath, args);
let timer: NodeJS.Timeout;
let output: string = '';
scripts.forEach(script => {
scriptExecution.stdin.write(`${script}\n`);
});
scriptExecution.stdin.end();
// Add listeners to print stdout and stderr if an output channel was provided
scriptExecution.stdout.on('data', data => {
if (outputChannel) {
this.outputDataChunk(data, outputChannel, ' stdout: ');
}
output = output + data.toString();
});
scriptExecution.stderr.on('data', data => {
if (outputChannel) {
this.outputDataChunk(data, outputChannel, ' stderr: ');
}
output = output + data.toString();
});
scriptExecution.on('exit', (code) => {
if (timer) {
clearTimeout(timer);
}
if (code === 0) {
resolve(output);
} else {
reject(`Process exited with code: ${code}. output: ${output}`);
}
});
timer = setTimeout(() => {
try {
scriptExecution.kill();
} catch (error) {
console.log(error);
}
}, this.timeout);
});
}
public async executeBufferedCommand(cmd: string, outputChannel?: vscode.OutputChannel): Promise<string> {
return new Promise<string>((resolve, reject) => {
if (outputChannel) {
outputChannel.appendLine(` > ${cmd}`);
}
let child = childProcess.exec(cmd, {
timeout: this.timeout
}, (err, stdout) => {
if (err) {
reject(err);
} else {
resolve(stdout);
}
});
// Add listeners to print stdout and stderr if an output channel was provided
if (outputChannel) {
child.stdout.on('data', data => { this.outputDataChunk(data, outputChannel, ' stdout: '); });
child.stderr.on('data', data => { this.outputDataChunk(data, outputChannel, ' stderr: '); });
}
});
}
private outputDataChunk(data: string | Buffer, outputChannel: vscode.OutputChannel, header: string): void {
data.toString().split(/\r?\n/)
.forEach(line => {
outputChannel.appendLine(header + line);
});
}
}

View File

@@ -0,0 +1,214 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as nbExtensionApis from '../typings/notebookServices';
import { ApiWrapper } from './apiWrapper';
import * as constants from '../common/constants';
import * as utils from '../common/utils';
const maxNumberOfRetries = 2;
const listPythonPackagesQuery = `
Declare @tablevar table(name NVARCHAR(MAX), version NVARCHAR(MAX))
insert into @tablevar(name, version)
EXEC sp_execute_external_script
@language=N'Python',
@script=N'import pkg_resources
import pandas
OutputDataSet = pandas.DataFrame([(d.project_name, d.version) for d in pkg_resources.working_set])'
select e.name, version from sys.external_libraries e join @tablevar t on e.name = t.name
where [language] = 'PYTHON'
`;
const listRPackagesQuery = `
Declare @tablevar table(name NVARCHAR(MAX), version NVARCHAR(MAX))
insert into @tablevar(name, version)
EXEC sp_execute_external_script
@language=N'R',
@script=N'
OutputDataSet <- as.data.frame(installed.packages()[,c(1,3)])'
select e.name, version from sys.external_libraries e join @tablevar t on e.name = t.name
where [language] = 'R'
`;
const checkMlInstalledQuery = `
Declare @tablevar table(name NVARCHAR(MAX), min INT, max INT, config_value bit, run_value bit)
insert into @tablevar(name, min, max, config_value, run_value) exec sp_configure
Declare @external_script_enabled bit
SELECT @external_script_enabled=config_value FROM @tablevar WHERE name = 'external scripts enabled'
SELECT @external_script_enabled`;
const checkLanguageInstalledQuery = `
SELECT is_installed
FROM sys.dm_db_external_language_stats s, sys.external_languages l
WHERE s.external_language_id = l.external_language_id AND language = '#LANGUAGE#'`;
const modifyExternalScriptConfigQuery = `
EXEC sp_configure 'external scripts enabled', #CONFIG_VALUE#;
RECONFIGURE WITH OVERRIDE;
Declare @tablevar table(name NVARCHAR(MAX), min INT, max INT, config_value bit, run_value bit)
insert into @tablevar(name, min, max, config_value, run_value) exec sp_configure
Declare @external_script_enabled bit
SELECT @external_script_enabled=config_value FROM @tablevar WHERE name = 'external scripts enabled'
SELECT @external_script_enabled`;
/**
* SQL Query runner
*/
export class QueryRunner {
constructor(private _apiWrapper: ApiWrapper) {
}
/**
* Returns python packages installed in SQL server instance
* @param connection SQL Connection
*/
public async getPythonPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
return this.getPackages(connection, databaseName, listPythonPackagesQuery);
}
/**
* Returns python packages installed in SQL server instance
* @param connection SQL Connection
*/
public async getRPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
return this.getPackages(connection, databaseName, listRPackagesQuery);
}
private async getPackages(connection: azdata.connection.ConnectionProfile, databaseName: string, script: string): Promise<nbExtensionApis.IPackageDetails[]> {
let packages: nbExtensionApis.IPackageDetails[] = [];
let result: azdata.SimpleExecuteResult | undefined = undefined;
for (let index = 0; index < maxNumberOfRetries; index++) {
result = await this.runQuery(connection, utils.getScriptWithDBChange(connection.databaseName, databaseName, script));
if (result && result.rowCount > 0) {
break;
}
}
if (result && result.rows.length > 0) {
packages = result.rows.map(row => {
return {
name: row[0].displayValue,
version: row[1].displayValue
};
});
}
return packages;
}
/**
* Updates External Script Config in a SQL server instance
* @param connection SQL Connection
* @param enable if true the config will be enabled otherwise it will be disabled
*/
public async updateExternalScriptConfig(connection: azdata.connection.ConnectionProfile, enable: boolean): Promise<void> {
let query = modifyExternalScriptConfigQuery;
let configValue = enable ? '1' : '0';
query = query.replace('#CONFIG_VALUE#', configValue);
await this.runQuery(connection, query);
}
/**
* Returns true if python installed in the give SQL server instance
*/
public async isPythonInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
return this.isLanguageInstalled(connection, constants.pythonLanguageName);
}
/**
* Returns true if R installed in the give SQL server instance
*/
public async isRInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
return this.isLanguageInstalled(connection, constants.rLanguageName);
}
/**
* Returns true if language installed in the give SQL server instance
*/
private async isLanguageInstalled(connection: azdata.connection.ConnectionProfile, language: string): Promise<boolean> {
let result = await this.runQuery(connection, checkLanguageInstalledQuery.replace('#LANGUAGE#', language));
let isInstalled = false;
if (result && result.rows && result.rows.length > 0) {
isInstalled = result.rows[0][0].displayValue === '1';
}
return isInstalled;
}
/**
* Returns true if mls is installed in the give SQL server instance
*/
public async isMachineLearningServiceEnabled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
let result = await this.runQuery(connection, checkMlInstalledQuery);
let isEnabled = false;
if (result && result.rows && result.rows.length > 0) {
isEnabled = result.rows[0][0].displayValue === '1';
}
return isEnabled;
}
public async runQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
let result: azdata.SimpleExecuteResult | undefined = undefined;
try {
if (connection) {
let connectionUri = await this._apiWrapper.getUriForConnection(connection.connectionId);
let queryProvider = this._apiWrapper.getProvider<azdata.QueryProvider>(connection.providerId, azdata.DataProviderType.QueryProvider);
if (queryProvider) {
result = await queryProvider.runQueryAndReturn(connectionUri, query);
}
}
} catch (error) {
console.log(error);
}
return result;
}
/**
* Executes the query but doesn't fail it is fails
* @param connection SQL connection
* @param query query to run
*/
public async safeRunQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
try {
return await this.runQuery(connection, query);
} catch (error) {
//console.log(error);
return undefined;
}
}
/**
* Executes the query but doesn't fail it is fails
* @param connection SQL connection
* @param query query to run
*/
public async runWithDatabaseChange(connection: azdata.connection.ConnectionProfile, query: string, queryDb: string): Promise<azdata.SimpleExecuteResult | undefined> {
if (connection) {
try {
return await this.runQuery(connection, `
USE [${utils.doubleEscapeSingleBrackets(queryDb)}]
${query}`);
} catch (error) {
console.log(error);
}
finally {
this.safeRunQuery(connection, `USE [${utils.doubleEscapeSingleBrackets(connection.databaseName || 'master')}]`);
}
}
return undefined;
}
}

View File

@@ -0,0 +1,261 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as path from 'path';
import * as os from 'os';
import * as fs from 'fs';
import * as constants from './constants';
import { promisify } from 'util';
import { ApiWrapper } from './apiWrapper';
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
let tempFilePath: string = '';
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, content);
let result = await command(tempFilePath);
return result;
}
finally {
await deleteFile(tempFilePath);
}
}
/**
* Deletes a file
* @param filePath file path
*/
export async function deleteFile(filePath: string) {
if (filePath) {
await fs.promises.unlink(filePath);
}
}
export async function readFileInHex(filePath: string): Promise<string> {
let buffer = await fs.promises.readFile(filePath);
return `0X${buffer.toString('hex')}`;
}
export async function exists(path: string): Promise<boolean> {
return promisify(fs.exists)(path);
}
export async function createFolder(dirPath: string): Promise<void> {
let folderExists = await exists(dirPath);
if (!folderExists) {
await fs.promises.mkdir(dirPath);
}
}
export function getPythonInstallationLocation(rootFolder: string) {
return path.join(rootFolder, 'python');
}
export function getPythonExePath(rootFolder: string): string {
return path.join(
getPythonInstallationLocation(rootFolder),
constants.pythonBundleVersion,
process.platform === constants.winPlatform ? 'python.exe' : 'bin/python3');
}
export function getPackageFilePath(rootFolder: string, packageName: string): string {
return path.join(
rootFolder,
constants.rLPackagedFolderName,
packageName);
}
export function getRPackagesFolderPath(rootFolder: string): string {
return path.join(
rootFolder,
constants.rLPackagedFolderName);
}
/**
* Compares two version strings to see which is greater.
* @param first First version string to compare.
* @param second Second version string to compare.
* @returns 1 if the first version is greater, -1 if it's less, and 0 otherwise.
*/
export function comparePackageVersions(first: string, second: string): number {
let firstVersion = first.split('.').map(numStr => Number.parseInt(numStr));
let secondVersion = second.split('.').map(numStr => Number.parseInt(numStr));
// If versions have different lengths, then append zeroes to the shorter one
if (firstVersion.length > secondVersion.length) {
let diff = firstVersion.length - secondVersion.length;
secondVersion = secondVersion.concat(new Array(diff).fill(0));
} else if (secondVersion.length > firstVersion.length) {
let diff = secondVersion.length - firstVersion.length;
firstVersion = firstVersion.concat(new Array(diff).fill(0));
}
for (let i = 0; i < firstVersion.length; ++i) {
if (firstVersion[i] > secondVersion[i]) {
return 1;
} else if (firstVersion[i] < secondVersion[i]) {
return -1;
}
}
return 0;
}
export function sortPackageVersions(versions: string[], ascending: boolean = true) {
return versions.sort((first, second) => {
let compareResult = comparePackageVersions(first, second);
if (ascending) {
return compareResult;
} else {
return compareResult * -1;
}
});
}
export function isWindows(): boolean {
return process.platform === 'win32';
}
/**
* Escapes all single-quotes (') by prefixing them with another single quote ('')
* ' => ''
* @param value The string to escape
*/
export function doubleEscapeSingleQuotes(value: string | undefined): string {
return value ? value.replace(/'/g, '\'\'') : '';
}
/**
* Escapes all single-bracket ([]) by replacing them with another bracket quote ([[]])
* ' => ''
* @param value The string to escape
*/
export function doubleEscapeSingleBrackets(value: string | undefined): string {
return value ? value.replace(/\[/g, '[[').replace(/\]/g, ']]') : '';
}
/**
* Installs dependencies for the extension
*/
export async function executeTasks<T>(apiWrapper: ApiWrapper, taskName: string, dependencies: PromiseLike<T>[], parallel: boolean): Promise<T[]> {
return new Promise<T[]>((resolve, reject) => {
let msgTaskName = taskName;
apiWrapper.startBackgroundOperation({
displayName: msgTaskName,
description: msgTaskName,
isCancelable: false,
operation: async op => {
try {
let result: T[] = [];
// Install required packages
//
if (parallel) {
result = await Promise.all(dependencies);
} else {
for (let index = 0; index < dependencies.length; index++) {
result.push(await dependencies[index]);
}
}
op.updateStatus(azdata.TaskStatus.Succeeded);
resolve(result);
} catch (error) {
let errorMsg = constants.taskFailedError(taskName, error ? error.message : '');
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
reject(errorMsg);
}
}
});
});
}
export async function promptConfirm(message: string, apiWrapper: ApiWrapper): Promise<boolean> {
let choices: { [id: string]: boolean } = {};
choices[constants.msgYes] = true;
choices[constants.msgNo] = false;
let options = {
placeHolder: message
};
let result = await apiWrapper.showQuickPick(Object.keys(choices).map(c => {
return {
label: c
};
}), options);
if (result === undefined) {
throw Error('invalid selection');
}
return choices[result.label] || false;
}
export function makeLinuxPath(filePath: string): string {
const parts = filePath.split('\\');
return parts.join('/');
}
/**
*
* @param currentDb Wraps the given script with database switch scripts
* @param databaseName
* @param script
*/
export function getScriptWithDBChange(currentDb: string, databaseName: string, script: string): string {
if (!currentDb) {
currentDb = 'master';
}
let escapedDbName = doubleEscapeSingleBrackets(databaseName);
let escapedCurrentDbName = doubleEscapeSingleBrackets(currentDb);
return `
USE [${escapedDbName}]
${script}
USE [${escapedCurrentDbName}]
`;
}
/**
* Returns full name of model registration table
* @param config config
*/
export function getRegisteredModelsThreePartsName(db: string, table: string, schema: string) {
const dbName = doubleEscapeSingleBrackets(db);
const schemaName = doubleEscapeSingleBrackets(schema);
const tableName = doubleEscapeSingleBrackets(table);
return `[${dbName}].[${schemaName}].[${tableName}]`;
}
/**
* Returns full name of model registration table
* @param config config object
*/
export function getRegisteredModelsTwoPartsName(table: string, schema: string) {
const schemaName = doubleEscapeSingleBrackets(schema);
const tableName = doubleEscapeSingleBrackets(table);
return `[${schemaName}].[${tableName}]`;
}
/**
* Write a file using a hex string
* @param content file content
*/
export async function writeFileFromHex(content: string): Promise<string> {
content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content;
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, Buffer.from(content, 'hex'));
return tempFilePath;
}
/**
*
* @param filePath Returns file name
*/
export function getFileName(filePath: string) {
if (filePath) {
return filePath.replace(/^.*[\\\/]/, '');
} else {
return '';
}
}

View File

@@ -0,0 +1,138 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import { ApiWrapper } from '../common/apiWrapper';
import * as constants from '../common/constants';
import { promises as fs } from 'fs';
import * as path from 'path';
import { PackageConfigModel } from './packageConfigModel';
const configFileName = 'config.json';
const defaultPythonExecutable = 'python';
const defaultRExecutable = 'r';
/**
* Extension Configuration from app settings
*/
export class Config {
private _configValues: any;
constructor(private _root: string, private _apiWrapper: ApiWrapper) {
}
/**
* Loads the config values
*/
public async load(): Promise<void> {
const rawConfig = await fs.readFile(path.join(this._root, configFileName));
this._configValues = JSON.parse(rawConfig.toString());
}
/**
* Returns the config value of required python packages
*/
public get requiredSqlPythonPackages(): PackageConfigModel[] {
return this._configValues.sqlPackageManagement.requiredPythonPackages;
}
/**
* Returns the config value of required r packages
*/
public get requiredSqlRPackages(): PackageConfigModel[] {
return this._configValues.sqlPackageManagement.requiredRPackages;
}
/**
* Returns r packages repository
*/
public get rPackagesRepository(): string {
return this._configValues.sqlPackageManagement.rPackagesRepository;
}
/**
* Returns python path from user settings
*/
public get pythonExecutable(): string {
return this.config.get(constants.pythonPathConfigKey) || defaultPythonExecutable;
}
/**
* Returns true if python package management is enabled
*/
public get pythonEnabled(): boolean {
return this.config.get(constants.pythonEnabledConfigKey) || false;
}
/**
* Returns true if r package management is enabled
*/
public get rEnabled(): boolean {
return this.config.get(constants.rEnabledConfigKey) || false;
}
/**
* Returns registered models table name
*/
public get registeredModelTableName(): string {
return this._configValues.modelManagement.registeredModelsTableName;
}
/**
* Returns registered models table schema name
*/
public get registeredModelTableSchemaName(): string {
return this._configValues.modelManagement.registeredModelsTableSchemaName;
}
/**
* Returns registered models table name
*/
public get registeredModelDatabaseName(): string {
return this._configValues.modelManagement.registeredModelsDatabaseName;
}
/**
* Returns Azure ML API
*/
public get amlModelManagementUrl(): string {
return this._configValues.modelManagement.amlModelManagementUrl;
}
/**
* Returns Azure ML API
*/
public get amlExperienceUrl(): string {
return this._configValues.modelManagement.amlExperienceUrl;
}
/**
* Returns Azure ML API Version
*/
public get amlApiVersion(): string {
return this._configValues.modelManagement.amlApiVersion;
}
/**
* Returns model management python packages
*/
public get modelsRequiredPythonPackages(): PackageConfigModel[] {
return this._configValues.modelManagement.requiredPythonPackages;
}
/**
* Returns r path from user settings
*/
public get rExecutable(): string {
return this.config.get(constants.rPathConfigKey) || defaultRExecutable;
}
private get config(): vscode.WorkspaceConfiguration {
return this._apiWrapper.getConfiguration(constants.mlsConfigKey);
}
}

View File

@@ -0,0 +1,35 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/**
* The model for package config value
*/
export interface PackageConfigModel {
/**
* Package name
*/
name: string;
/**
* Package version
*/
version?: string;
/**
* Package repository
*/
repository?: string;
/**
* Package download url
*/
downloadUrl?: string;
/**
* Package file name if package has download url
*/
fileName?: string;
}

View File

@@ -0,0 +1,165 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as nbExtensionApis from '../typings/notebookServices';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import { ProcessService } from '../common/processService';
import { Config } from '../configurations/config';
import { PackageManagementService } from '../packageManagement/packageManagementService';
import { HttpClient } from '../common/httpClient';
import { ModelManagementController } from '../views/models/modelManagementController';
import { DeployedModelService } from '../modelManagement/deployedModelService';
import { AzureModelRegistryService } from '../modelManagement/azureModelRegistryService';
import { ModelPythonClient } from '../modelManagement/modelPythonClient';
import { PredictService } from '../prediction/predictService';
import { DashboardWidget } from '../views/widgets/dashboardWidget';
import { ModelConfigRecent } from '../modelManagement/modelConfigRecent';
/**
* The main controller class that initializes the extension
*/
export default class MainController implements vscode.Disposable {
private _outputChannel: vscode.OutputChannel;
private _rootPath = this._context.extensionPath;
private _config: Config;
public constructor(
private _context: vscode.ExtensionContext,
private _apiWrapper: ApiWrapper,
private _queryRunner: QueryRunner,
private _processService: ProcessService,
private _packageManager?: PackageManager,
private _packageManagementService?: PackageManagementService,
private _httpClient?: HttpClient
) {
this._outputChannel = this._apiWrapper.createOutputChannel(constants.extensionOutputChannel);
this._rootPath = this._context.extensionPath;
this._config = new Config(this._rootPath, this._apiWrapper);
}
/**
* Deactivates the extension
*/
public deactivate(): void {
}
/**
* Activates the extension
*/
public async activate(): Promise<boolean> {
await this.initialize();
return Promise.resolve(true);
}
/**
* Returns an instance of Server Installation from notebook extension
*/
private async getNotebookExtensionApis(): Promise<nbExtensionApis.IExtensionApi> {
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
if (nbExtension) {
await nbExtension.activate();
return (nbExtension.exports as nbExtensionApis.IExtensionApi);
} else {
throw new Error(constants.notebookExtensionNotLoaded);
}
}
private async initialize(): Promise<void> {
this._outputChannel.show(true);
let nbApis = await this.getNotebookExtensionApis();
await this._config.load();
let packageManager = this.getPackageManager(nbApis);
this._apiWrapper.registerCommand(constants.mlManagePackagesCommand, (async () => {
await packageManager.managePackages();
}));
// External Languages
//
let modelImporter = new ModelPythonClient(this._outputChannel, this._apiWrapper, this._processService, this._config, packageManager);
let modelRecentService = new ModelConfigRecent(this._context.globalState);
// Model Management
//
let registeredModelService = new DeployedModelService(this._apiWrapper, this._config, this._queryRunner, modelImporter, modelRecentService);
let azureModelsService = new AzureModelRegistryService(this._apiWrapper, this._config, this.httpClient, this._outputChannel);
let predictService = new PredictService(this._apiWrapper, this._queryRunner);
let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath,
azureModelsService, registeredModelService, predictService);
let dashboardWidget = new DashboardWidget(this._apiWrapper, this._rootPath);
dashboardWidget.register();
this._apiWrapper.registerCommand(constants.mlManageModelsCommand, (async () => {
await modelManagementController.manageRegisteredModels();
}));
this._apiWrapper.registerCommand(constants.mlImportModelCommand, (async () => {
await modelManagementController.registerModel(undefined);
}));
this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => {
await modelManagementController.predictModel();
}));
this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => {
await packageManager.installDependencies();
}));
this._apiWrapper.registerTaskHandler(constants.mlManagePackagesCommand, async () => {
await packageManager.managePackages();
});
}
/**
* Returns the package manager instance
*/
public getPackageManager(nbApis: nbExtensionApis.IExtensionApi): PackageManager {
if (!this._packageManager) {
this._packageManager = new PackageManager(this._outputChannel, this._rootPath, this._apiWrapper, this.packageManagementService, this._processService, this._config, this.httpClient);
this._packageManager.init();
this._packageManager.packageManageProviders.forEach(provider => {
nbApis.registerPackageManager(provider.providerId, provider);
});
}
return this._packageManager;
}
/**
* Returns the server config manager instance
*/
public get packageManagementService(): PackageManagementService {
if (!this._packageManagementService) {
this._packageManagementService = new PackageManagementService(this._apiWrapper, this._queryRunner);
}
return this._packageManagementService;
}
/**
* Returns the server config manager instance
*/
public get httpClient(): HttpClient {
if (!this._httpClient) {
this._httpClient = new HttpClient();
}
return this._httpClient;
}
/**
* Config instance
*/
public get config(): Config {
return this._config;
}
/**
* Disposes the extension
*/
public dispose(): void {
this.deactivate();
}
}

View File

@@ -0,0 +1,59 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as mssql from '../../../mssql';
import { ApiWrapper } from '../common/apiWrapper';
/**
* Manage package dialog model
*/
export class LanguageService {
public connection: azdata.connection.ConnectionProfile | undefined;
public connectionUrl: string = '';
constructor(
private _apiWrapper: ApiWrapper,
private _languageExtensionService: mssql.ILanguageExtensionService) {
}
public async load() {
this.connection = await this.getCurrentConnection();
this.connectionUrl = await this.getCurrentConnectionUrl();
}
public async getLanguageList(): Promise<mssql.ExternalLanguage[]> {
if (this.connectionUrl) {
return await this._languageExtensionService.listLanguages(this.connectionUrl);
}
return [];
}
public async deleteLanguage(languageName: string): Promise<void> {
if (this.connectionUrl) {
await this._languageExtensionService.deleteLanguage(this.connectionUrl, languageName);
}
}
public async updateLanguage(language: mssql.ExternalLanguage): Promise<void> {
if (this.connectionUrl) {
await this._languageExtensionService.updateLanguage(this.connectionUrl, language);
}
}
private async getCurrentConnectionUrl(): Promise<string> {
let connection = await this.getCurrentConnection();
if (connection) {
return await this._apiWrapper.getUriForConnection(connection.connectionId);
}
return '';
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
}

View File

@@ -0,0 +1,33 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import MainController from './controllers/mainController';
import { ApiWrapper } from './common/apiWrapper';
import { QueryRunner } from './common/queryRunner';
import { ProcessService } from './common/processService';
let controllers: MainController[] = [];
export async function activate(context: vscode.ExtensionContext): Promise<void> {
let apiWrapper = new ApiWrapper();
let queryRunner = new QueryRunner(apiWrapper);
let processService = new ProcessService();
// Start the main controller
//
let mainController = new MainController(context, apiWrapper, queryRunner, processService);
controllers.push(mainController);
context.subscriptions.push(mainController);
await mainController.activate();
}
export function deactivate(): void {
for (let controller of controllers) {
controller.deactivate();
}
}

View File

@@ -0,0 +1,62 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import * as Models from './interfaces';
import * as Mappers from './mappers';
import * as Parameters from './parameters';
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
export class Artifacts {
private readonly client: AzureMachineLearningWorkspacesContext;
constructor(client: AzureMachineLearningWorkspacesContext) {
this.client = client;
}
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options?: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams): Promise<Models.GetArtifactContentInformation2Response>;
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, callback: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): void;
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams, callback: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): void;
getArtifactContentInformation2(subscriptionId: string, resourceGroupName: string, workspaceName: string, origin: string, container: string, options?: Models.ArtifactAPIGetArtifactContentInformation2OptionalParams | msRest.ServiceCallback<Models.ArtifactContentInformationDto>, callback?: msRest.ServiceCallback<Models.ArtifactContentInformationDto>): Promise<Models.GetArtifactContentInformation2Response> {
return this.client.sendOperationRequest(
{
subscriptionId,
resourceGroupName,
workspaceName,
origin,
container,
options
},
getArtifactContentInformation2OperationSpec,
callback) as Promise<Models.GetArtifactContentInformation2Response>;
}
}
const serializer = new msRest.Serializer(Mappers);
const getArtifactContentInformation2OperationSpec: msRest.OperationSpec = {
httpMethod: 'GET',
path: 'artifact/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}/artifacts/contentinfo/{origin}/{container}',
urlParameters: [
Parameters.subscriptionId,
Parameters.resourceGroupName,
Parameters.workspaceName,
Parameters.origin,
Parameters.container,
Parameters.apiVersion
],
queryParameters: [
Parameters.projectName0,
Parameters.path1,
Parameters.accountName
],
responses: {
200: {
bodyMapper: Mappers.ArtifactContentInformationDto
},
default: {}
},
serializer
};

View File

@@ -0,0 +1,78 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import * as Models from './interfaces';
import * as Mappers from './mappers';
import * as Parameters from './parameters';
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
export class Assets {
private readonly client: AzureMachineLearningWorkspacesContext;
constructor(client: AzureMachineLearningWorkspacesContext) {
this.client = client;
}
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
options?: msRest.RequestOptionsBase
): Promise<Models.AssetsQueryByIdResponse>;
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
callback: msRest.ServiceCallback<Models.Asset>
): void;
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
options: msRest.RequestOptionsBase,
callback: msRest.ServiceCallback<Models.Asset>
): void;
queryById(
subscriptionId: string,
resourceGroup: string,
workspace: string,
id: string,
options?: msRest.RequestOptionsBase | msRest.ServiceCallback<Models.Asset>,
callback?: msRest.ServiceCallback<Models.Asset>
): Promise<Models.AssetsQueryByIdResponse> {
return this.client.sendOperationRequest(
{
subscriptionId,
resourceGroup,
workspace,
id,
options
},
queryByIdOperationSpec,
callback
) as Promise<Models.AssetsQueryByIdResponse>;
}
}
const serializer = new msRest.Serializer(Mappers);
const queryByIdOperationSpec: msRest.OperationSpec = {
httpMethod: 'GET',
path:
'modelmanagement/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/Microsoft.MachineLearningServices/workspaces/{workspace}/assets/{id}',
urlParameters: [Parameters.subscriptionId, Parameters.resourceGroup, Parameters.workspace, Parameters.id],
responses: {
200: {
bodyMapper: Mappers.Asset
},
default: {
bodyMapper: Mappers.ModelErrorResponse
}
},
serializer
};

View File

@@ -0,0 +1,329 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../common/apiWrapper';
import * as constants from '../common/constants';
import { azureResource } from '../typings/azure-resource';
import { AzureMachineLearningWorkspaces } from '@azure/arm-machinelearningservices';
import { TokenCredentials } from '@azure/ms-rest-js';
import { WorkspaceModels } from './workspacesModels';
import { AzureMachineLearningWorkspacesOptions, Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel, Asset, IArtifactParts } from './interfaces';
import { Config } from '../configurations/config';
import { Assets } from './assets';
import * as polly from 'polly-js';
import { Artifacts } from './artifacts';
import { HttpClient } from '../common/httpClient';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as path from 'path';
import * as os from 'os';
import * as utils from '../common/utils';
/**
* Azure Model Service
*/
export class AzureModelRegistryService {
private _amlClient: AzureMachineLearningWorkspaces | undefined;
private _modelClient: WorkspaceModels | undefined;
/**
* Creates new service
*/
constructor(
private _apiWrapper: ApiWrapper,
private _config: Config,
private _httpClient: HttpClient,
private _outputChannel: vscode.OutputChannel) {
}
/**
* Returns list of azure accounts
*/
public async getAccounts(): Promise<azdata.Account[]> {
return await this._apiWrapper.getAllAccounts();
}
/**
* Returns list of azure subscriptions
* @param account azure account
*/
public async getSubscriptions(account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[] | undefined> {
const data = <azureResource.GetSubscriptionsResult>await this._apiWrapper.executeCommand(constants.azureSubscriptionsCommand, account, true);
return data?.subscriptions;
}
/**
* Returns list of azure groups
* @param account azure account
* @param subscription azure subscription
*/
public async getGroups(
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[] | undefined> {
const data = <azureResource.GetResourceGroupsResult>await this._apiWrapper.executeCommand(constants.azureResourceGroupsCommand, account, subscription, true);
return data?.resourceGroups;
}
/**
* Returns list of workspaces
* @param account azure account
* @param subscription azure subscription
* @param resourceGroup azure resource group
*/
public async getWorkspaces(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
return await this.fetchWorkspaces(account, subscription, resourceGroup);
}
/**
* Returns list of models
* @param account azure account
* @param subscription azure subscription
* @param resourceGroup azure resource group
* @param workspace azure workspace
*/
public async getModels(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace): Promise<WorkspaceModel[] | undefined> {
return await this.fetchModels(account, subscription, resourceGroup, workspace);
}
/**
* Download an azure model to a temporary location
* @param account azure account
* @param subscription azure subscription
* @param resourceGroup azure resource group
* @param workspace azure workspace
* @param model azure model
*/
public async downloadModel(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel): Promise<string> {
let downloadedFilePath: string = '';
for (const tenant of account.properties.tenants) {
try {
const downloadUrls = await this.getAssetArtifactsDownloadLinks(account, subscription, resourceGroup, workspace, model, tenant);
if (downloadUrls && downloadUrls.length > 0) {
downloadedFilePath = await this.execDownloadArtifactTask(downloadUrls[0]);
}
} catch (error) {
console.log(error);
}
}
return downloadedFilePath;
}
public set AzureMachineLearningClient(value: AzureMachineLearningWorkspaces) {
this._amlClient = value;
}
public set ModelClient(value: WorkspaceModels) {
this._modelClient = value;
}
public async signInToAzure(): Promise<void> {
await this._apiWrapper.executeCommand(constants.signInToAzureCommand);
}
/**
* Execute the background task to download the artifact
*/
private async execDownloadArtifactTask(downloadUrl: string): Promise<string> {
let results = await utils.executeTasks(this._apiWrapper, constants.downloadModelMsgTaskName, [this.downloadArtifact(downloadUrl)], true);
return results && results.length > 0 ? results[0] : constants.noResultError;
}
private async downloadArtifact(downloadUrl: string): Promise<string> {
let tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await this._httpClient.download(downloadUrl, tempFilePath, this._outputChannel);
return tempFilePath;
}
private async fetchWorkspaces(account: azdata.Account, subscription: azureResource.AzureResourceSubscription, resourceGroup: azureResource.AzureResource | undefined): Promise<Workspace[]> {
let resources: Workspace[] = [];
try {
for (const tenant of account.properties.tenants) {
const client = await this.getAmlClient(account, subscription, tenant);
let result = resourceGroup ? await client.workspaces.listByResourceGroup(resourceGroup.name) : await client.workspaces.listBySubscription();
if (result) {
resources.push(...result);
}
}
} catch (error) {
console.log(error);
}
return resources;
}
private async fetchModels(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace): Promise<WorkspaceModel[]> {
let resources: WorkspaceModel[] = [];
for (const tenant of account.properties.tenants) {
try {
let options: AzureMachineLearningWorkspacesOptions = {
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
};
const client = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
let modelsClient = this.getModelClient(client);
resources = resources.concat(await modelsClient.listModels(resourceGroup.name, workspace.name || ''));
} catch (error) {
console.log(error);
}
}
return resources;
}
private async fetchModelAsset(
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel,
client: AzureMachineLearningWorkspaces): Promise<Asset> {
const modelId = this.getModelId(model);
if (modelId) {
let modelsClient = new Assets(client);
return await modelsClient.queryById(subscription.id, resourceGroup.name, workspace.name || '', modelId);
} else {
throw Error(constants.invalidModelIdError(model.url));
}
}
private async getAssetArtifactsDownloadLinks(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
resourceGroup: azureResource.AzureResource,
workspace: Workspace,
model: WorkspaceModel,
tenant: any): Promise<string[]> {
let options: AzureMachineLearningWorkspacesOptions = {
baseUri: this.getBaseUrl(workspace, this._config.amlModelManagementUrl)
};
const modelManagementClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
const asset = await this.fetchModelAsset(subscription, resourceGroup, workspace, model, modelManagementClient);
options.baseUri = this.getBaseUrl(workspace, this._config.amlExperienceUrl);
const experienceClient = await this.getAmlClient(account, subscription, tenant, options, this._config.amlApiVersion);
const artifactClient = new Artifacts(experienceClient);
let downloadLinks: string[] = [];
if (asset && asset.artifacts) {
const downloadLinkPromises: Array<Promise<string>> = [];
for (const artifact of asset.artifacts) {
const parts = artifact.id
? this.getPartsFromAssetIdOrPrefix(artifact.id)
: this.getPartsFromAssetIdOrPrefix(artifact.prefix);
if (parts) {
const promise = polly()
.waitAndRetry(3)
.executeForPromise(
async (): Promise<string> => {
const artifact = await artifactClient.getArtifactContentInformation2(
experienceClient.subscriptionId,
resourceGroup.name,
workspace.name || '',
parts.origin,
parts.container,
{ path: parts.path }
);
if (artifact) {
return artifact.contentUri || '';
} else {
return Promise.reject();
}
}
);
downloadLinkPromises.push(promise);
}
}
try {
downloadLinks = await Promise.all(downloadLinkPromises);
} catch (rejectedPromiseError) {
return rejectedPromiseError;
}
return downloadLinks;
} else {
throw Error(constants.noArtifactError(model.url));
}
}
private getPartsFromAssetIdOrPrefix(idOrPrefix: string | undefined): IArtifactParts | undefined {
const artifactRegex = /^(.+?)\/(.+?)\/(.+?)$/;
if (idOrPrefix) {
const parts = artifactRegex.exec(idOrPrefix);
if (parts && parts.length === 4) {
return {
origin: parts[1],
container: parts[2],
path: parts[3]
};
}
}
return undefined;
}
private getBaseUrl(workspace: Workspace, server: string): string {
let baseUri = `https://${workspace.location}.${server}`;
if (workspace.location === 'chinaeast2') {
baseUri = `https://${workspace.location}.${server}`;
}
return baseUri;
}
private getModelClient(amlClient: AzureMachineLearningWorkspaces) {
return this._modelClient ?? new WorkspaceModels(amlClient);
}
private async getAmlClient(
account: azdata.Account,
subscription: azureResource.AzureResourceSubscription,
tenant: any,
options: AzureMachineLearningWorkspacesOptions | undefined = undefined,
apiVersion: string | undefined = undefined): Promise<AzureMachineLearningWorkspaces> {
if (this._amlClient) {
return this._amlClient;
} else {
const tokens = await this._apiWrapper.getSecurityToken(account, azdata.AzureResource.ResourceManagement);
let token: string = '';
let tokenType: string | undefined = undefined;
if (tokens && tenant.id in tokens) {
const tokenForId = tokens[tenant.id];
if (tokenForId) {
token = tokenForId.token;
tokenType = tokenForId.tokenType;
}
}
const client = new AzureMachineLearningWorkspaces(new TokenCredentials(token, tokenType), subscription.id, options);
if (apiVersion) {
client.apiVersion = apiVersion;
}
return client;
}
}
private getModelId(model: WorkspaceModel): string {
const amlAssetRegex = /^aml:\/\/asset\/(.+)$/;
const id = model ? amlAssetRegex.exec(model.url || '') : undefined;
return id && id.length === 2 ? id[1] : '';
}
}

View File

@@ -0,0 +1,219 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import { Config } from '../configurations/config';
import { QueryRunner } from '../common/queryRunner';
import { ImportedModel, ImportedModelDetails, ModelParameters } from './interfaces';
import { ModelPythonClient } from './modelPythonClient';
import * as constants from '../common/constants';
import * as queries from './queries';
import { DatabaseTable } from '../prediction/interfaces';
import { ModelConfigRecent } from './modelConfigRecent';
/**
* Service to deployed models
*/
export class DeployedModelService {
/**
* Creates new instance
*/
constructor(
private _apiWrapper: ApiWrapper,
private _config: Config,
private _queryRunner: QueryRunner,
private _modelClient: ModelPythonClient,
private _recentModelService: ModelConfigRecent) {
}
/**
* Returns deployed models
*/
public async getDeployedModels(table: DatabaseTable): Promise<ImportedModel[]> {
let connection = await this.getCurrentConnection();
let list: ImportedModel[] = [];
if (!table.databaseName || !table.tableName || !table.schema) {
return [];
}
if (connection) {
const query = queries.getDeployedModelsQuery(table);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push(this.loadModelData(row, table));
});
}
} else {
throw Error(constants.noConnectionError);
}
return list;
}
/**
* Downloads model
* @param model model object
*/
public async downloadModel(model: ImportedModel): Promise<string> {
let connection = await this.getCurrentConnection();
if (connection) {
const query = queries.getModelContentQuery(model);
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
const content = result.rows[0][0].displayValue;
return await utils.writeFileFromHex(content);
} else {
throw Error(constants.invalidModelToSelectError);
}
} else {
throw Error(constants.noConnectionError);
}
}
/**
* Loads model parameters
*/
public async loadModelParameters(filePath: string): Promise<ModelParameters> {
return await this._modelClient.loadModelParameters(filePath);
}
/**
* Deploys local model
* @param filePath model file path
* @param details model details
*/
public async deployLocalModel(filePath: string, details: ImportedModelDetails | undefined, table: DatabaseTable) {
let connection = await this.getCurrentConnection();
if (connection && table.databaseName) {
await this.configureImport(connection, table);
let currentModels = await this.getDeployedModels(table);
const content = await utils.readFileInHex(filePath);
let modelToAdd: ImportedModel = Object.assign({}, {
id: 0,
content: content,
table: table
}, details);
await this._queryRunner.runWithDatabaseChange(connection, queries.getInsertModelQuery(modelToAdd, table), table.databaseName);
let updatedModels = await this.getDeployedModels(table);
if (updatedModels.length < currentModels.length + 1) {
throw Error(constants.importModelFailedError(details?.modelName, filePath));
}
} else {
throw new Error(constants.noConnectionError);
}
}
/**
* Updates a model
*/
public async updateModel(model: ImportedModel) {
let connection = await this.getCurrentConnection();
if (connection && model && model.table && model.table.databaseName) {
await this._queryRunner.runWithDatabaseChange(connection, queries.getUpdateModelQuery(model), model.table.databaseName);
} else {
throw new Error(constants.noConnectionError);
}
}
/**
* Updates a model
*/
public async deleteModel(model: ImportedModel) {
let connection = await this.getCurrentConnection();
if (connection && model && model.table && model.table.databaseName) {
await this._queryRunner.runWithDatabaseChange(connection, queries.getDeleteModelQuery(model), model.table.databaseName);
} else {
throw new Error(constants.noConnectionError);
}
}
public async configureImport(connection: azdata.connection.ConnectionProfile, table: DatabaseTable) {
if (connection && table.databaseName) {
let query = queries.getDatabaseConfigureQuery(table);
await this._queryRunner.safeRunQuery(connection, query);
query = queries.getConfigureTableQuery(table);
await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
}
}
/**
* Verifies if the given table name is valid to be used as import table. If table doesn't exist returns true to create new table
* Otherwise verifies the schema and returns true if the schema is supported
* @param connection database connection
* @param table config table name
*/
public async verifyConfigTable(table: DatabaseTable): Promise<boolean> {
let connection = await this.getCurrentConnection();
if (connection && table.databaseName) {
let databases = await this._apiWrapper.listDatabases(connection.connectionId);
// If database exist verify the table schema
//
if ((await databases).find(x => x === table.databaseName)) {
const query = queries.getConfigTableVerificationQuery(table);
const result = await this._queryRunner.runWithDatabaseChange(connection, query, table.databaseName);
return result !== undefined && result.rows.length > 0 && result.rows[0][0].displayValue === '1';
} else {
return true;
}
} else {
throw new Error(constants.noConnectionError);
}
}
public async getRecentImportTable(): Promise<DatabaseTable> {
let connection = await this.getCurrentConnection();
let table: DatabaseTable | undefined;
if (connection) {
table = this._recentModelService.getModelTable(connection);
if (!table) {
table = {
databaseName: connection.databaseName ?? 'master',
tableName: this._config.registeredModelTableName,
schema: this._config.registeredModelTableSchemaName
};
}
} else {
throw new Error(constants.noConnectionError);
}
return table;
}
public async storeRecentImportTable(importTable: DatabaseTable): Promise<void> {
let connection = await this.getCurrentConnection();
if (connection) {
this._recentModelService.storeModelTable(connection, importTable);
} else {
throw new Error(constants.noConnectionError);
}
}
private loadModelData(row: azdata.DbCellValue[], table: DatabaseTable): ImportedModel {
return {
id: +row[0].displayValue,
modelName: row[1].displayValue,
description: row[2].displayValue,
version: row[3].displayValue,
created: row[4].displayValue,
framework: row[5].displayValue,
frameworkVersion: row[6].displayValue,
deploymentTime: row[7].displayValue,
deployedBy: row[8].displayValue,
runId: row[9].displayValue,
table: table
};
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
}

View File

@@ -0,0 +1,241 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import { Resource } from '@azure/arm-machinelearningservices/esm/models';
import { DatabaseTable } from '../prediction/interfaces';
/**
* An interface representing ListWorkspaceModelResult.
*/
export interface ListWorkspaceModelsResult extends Array<WorkspaceModel> {
}
/**
* An interface representing Workspace model
*/
export interface WorkspaceModel extends Resource {
framework?: string;
frameworkVersion?: string;
createdBy?: string;
createdTime?: string;
experimentName?: string;
outputsSchema?: Array<string>;
url?: string;
}
/**
* An interface representing Workspace model list response
*/
export type WorkspacesModelsResponse = ListWorkspaceModelsResult & {
/**
* The underlying HTTP response.
*/
_response: msRest.HttpResponse & {
/**
* The response body as text (string format)
*/
bodyAsText: string;
/**
* The response body as parsed JSON or XML
*/
parsedBody: ListWorkspaceModelsResult;
};
};
/**
* An interface representing imported model
*/
export interface ImportedModel extends ImportedModelDetails {
id: number;
content?: string;
table: DatabaseTable;
}
export interface ModelParameter {
name: string;
type: string;
}
export interface ModelParameters {
inputs: ModelParameter[],
outputs: ModelParameter[]
}
/**
* An interface representing imported model
*/
export interface ImportedModelDetails {
modelName: string;
created?: string;
deploymentTime?: string;
version?: string;
description?: string;
fileName?: string;
framework?: string;
frameworkVersion?: string;
runId?: string;
deployedBy?: string;
}
/**
* The Artifact definition.
*/
export interface ArtifactDetails {
/**
* The Artifact Id.
*/
id?: string;
/**
* The Artifact prefix.
*/
prefix?: string;
}
/**
* @interface
* An interface representing Asset.
* The Asset definition.
*
*/
export interface Asset {
/**
* @member {string} [id] The Asset Id.
*/
id?: string;
/**
* @member {string} [name] The name of the Asset.
*/
name?: string;
/**
* @member {string} [description] The Asset description.
*/
description?: string;
/**
* @member {ArtifactDetails[]} [artifacts] A list of child artifacts.
*/
artifacts?: ArtifactDetails[];
/**
* @member {string[]} [tags] The Asset tag list.
*/
tags?: string[];
/**
* @member {{ [propertyName: string]: string }} [kvTags] The Asset tag
* dictionary. Tags are mutable.
*/
kvTags?: { [propertyName: string]: string };
/**
* @member {{ [propertyName: string]: string }} [properties] The Asset
* property dictionary. Properties are immutable.
*/
properties?: { [propertyName: string]: string };
/**
* @member {string} [runid] The RunId associated with this Asset.
*/
runid?: string;
/**
* @member {string} [projectid] The project Id.
*/
projectid?: string;
/**
* @member {{ [propertyName: string]: string }} [meta] A dictionary
* containing metadata about the Asset.
*/
meta?: { [propertyName: string]: string };
/**
* @member {Date} [createdTime] The time the Asset was created in UTC.
*/
createdTime?: Date;
}
/**
* Contains response data for the queryById operation.
*/
export type AssetsQueryByIdResponse = Asset & {
/**
* The underlying HTTP response.
*/
_response: msRest.HttpResponse & {
/**
* The response body as text (string format)
*/
bodyAsText: string;
/**
* The response body as parsed JSON or XML
*/
parsedBody: Asset;
};
};
export interface IArtifactParts {
origin: string;
container: string;
path: string;
}
/**
* @interface
* An interface representing ArtifactContentInformationDto.
*/
export interface ArtifactContentInformationDto {
/**
* @member {string} [contentUri]
*/
contentUri?: string;
/**
* @member {string} [origin]
*/
origin?: string;
/**
* @member {string} [container]
*/
container?: string;
/**
* @member {string} [path]
*/
path?: string;
}
/**
* Contains response data for the getArtifactContentInformation2 operation.
*/
export type GetArtifactContentInformation2Response = ArtifactContentInformationDto & {
/**
* The underlying HTTP response.
*/
_response: msRest.HttpResponse & {
/**
* The response body as text (string format)
*/
bodyAsText: string;
/**
* The response body as parsed JSON or XML
*/
parsedBody: ArtifactContentInformationDto;
};
};
/**
* @interface
* An interface representing ArtifactAPIGetArtifactContentInformation2OptionalParams.
* Optional Parameters.
*
* @extends RequestOptionsBase
*/
export interface ArtifactAPIGetArtifactContentInformation2OptionalParams extends msRest.RequestOptionsBase {
/**
* @member {string} [projectName]
*/
projectName?: string;
/**
* @member {string} [path]
*/
path?: string;
/**
* @member {string} [accountName]
*/
accountName?: string;
}

View File

@@ -0,0 +1,320 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
export const Resource: msRest.CompositeMapper = {
serializedName: 'Resource',
type: {
name: 'Composite',
className: 'Resource',
modelProperties: {
id: {
readOnly: true,
serializedName: 'id',
type: {
name: 'String'
}
},
name: {
readOnly: true,
serializedName: 'name',
type: {
name: 'String'
}
},
identity: {
readOnly: true,
serializedName: 'identity',
type: {
name: 'Composite',
className: 'Identity'
}
},
location: {
serializedName: 'location',
type: {
name: 'String'
}
},
type: {
readOnly: true,
serializedName: 'type',
type: {
name: 'String'
}
},
tags: {
serializedName: 'tags',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
}
}
}
};
export const ListWorkspaceModelsResult: msRest.CompositeMapper = {
serializedName: 'ListWorkspaceModelsResult',
type: {
name: 'Composite',
className: 'ListWorkspaceModelsResult',
modelProperties: {
value: {
serializedName: '',
type: {
name: 'Sequence',
element: {
type: {
name: 'Composite',
className: 'WorkspaceModel'
}
}
}
},
nextLink: {
serializedName: 'nextLink',
type: {
name: 'String'
}
}
}
}
};
export const WorkspaceModel: msRest.CompositeMapper = {
serializedName: 'WorkspaceModel',
type: {
name: 'Composite',
className: 'WorkspaceModel',
modelProperties: {
...Resource.type.modelProperties,
framework: {
readOnly: true,
serializedName: 'framework',
type: {
name: 'String'
}
},
}
}
};
export const MachineLearningServiceError: msRest.CompositeMapper = {
serializedName: 'MachineLearningServiceError',
type: {
name: 'Composite',
className: 'MachineLearningServiceError',
modelProperties: {
error: {
readOnly: true,
serializedName: 'error',
type: {
name: 'Composite',
className: 'ErrorResponse'
}
}
}
}
};
export const ModelErrorResponse: msRest.CompositeMapper = {
serializedName: 'ModelErrorResponse',
type: {
name: 'Composite',
className: 'ModelErrorResponse',
modelProperties: {
code: {
serializedName: 'code',
type: {
name: 'String'
}
},
statusCode: {
serializedName: 'statusCode',
type: {
name: 'Number'
}
},
message: {
serializedName: 'message',
type: {
name: 'String'
}
},
details: {
serializedName: 'details',
type: {
name: 'Sequence',
element: {
type: {
name: 'Composite',
className: 'ErrorDetails'
}
}
}
}
}
}
};
export const ArtifactDetails: msRest.CompositeMapper = {
serializedName: 'ArtifactDetails',
type: {
name: 'Composite',
className: 'ArtifactDetails',
modelProperties: {
id: {
serializedName: 'id',
type: {
name: 'String'
}
},
prefix: {
serializedName: 'prefix',
type: {
name: 'String'
}
}
}
}
};
export const Asset: msRest.CompositeMapper = {
serializedName: 'Asset',
type: {
name: 'Composite',
className: 'Asset',
modelProperties: {
id: {
serializedName: 'id',
type: {
name: 'String'
}
},
name: {
serializedName: 'name',
type: {
name: 'String'
}
},
description: {
serializedName: 'description',
type: {
name: 'String'
}
},
artifacts: {
serializedName: 'artifacts',
type: {
name: 'Sequence',
element: {
type: {
name: 'Composite',
className: 'ArtifactDetails'
}
}
}
},
tags: {
serializedName: 'tags',
type: {
name: 'Sequence',
element: {
type: {
name: 'String'
}
}
}
},
kvTags: {
serializedName: 'kvTags',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
},
properties: {
serializedName: 'properties',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
},
runid: {
serializedName: 'runid',
type: {
name: 'String'
}
},
projectid: {
serializedName: 'projectid',
type: {
name: 'String'
}
},
meta: {
serializedName: 'meta',
type: {
name: 'Dictionary',
value: {
type: {
name: 'String'
}
}
}
},
createdTime: {
serializedName: 'createdTime',
type: {
name: 'DateTime'
}
}
}
}
};
export const ArtifactContentInformationDto: msRest.CompositeMapper = {
serializedName: 'ArtifactContentInformationDto',
type: {
name: 'Composite',
className: 'ArtifactContentInformationDto',
modelProperties: {
contentUri: {
serializedName: 'contentUri',
type: {
name: 'String'
}
},
origin: {
serializedName: 'origin',
type: {
name: 'String'
}
},
container: {
serializedName: 'container',
type: {
name: 'String'
}
},
path: {
serializedName: 'path',
type: {
name: 'String'
}
}
}
}
};

View File

@@ -0,0 +1,30 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import { DatabaseTable } from '../prediction/interfaces';
const TableConfigName = 'MLS_ModelTableConfigName';
export class ModelConfigRecent {
/**
*
*/
constructor(private _memento: vscode.Memento) {
}
public getModelTable(connection: azdata.connection.ConnectionProfile): DatabaseTable | undefined {
return this._memento.get<DatabaseTable>(this.getKey(connection));
}
public storeModelTable(connection: azdata.connection.ConnectionProfile, table: DatabaseTable): void {
this._memento.update(this.getKey(connection), table);
}
private getKey(connection: azdata.connection.ConnectionProfile): string {
return `${TableConfigName}_${connection.serverName}`;
}
}

View File

@@ -0,0 +1,128 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ProcessService } from '../common/processService';
import { Config } from '../configurations/config';
import { ApiWrapper } from '../common/apiWrapper';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as utils from '../common/utils';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
import * as os from 'os';
import { ModelParameters } from './interfaces';
/**
* Python client for ONNX models
*/
export class ModelPythonClient {
/**
* Creates new instance
*/
constructor(private _outputChannel: vscode.OutputChannel, private _apiWrapper: ApiWrapper, private _processService: ProcessService, private _config: Config, private _packageManager: PackageManager) {
}
/**
* Deploys models in the SQL database using mlflow
* @param connection
* @param modelPath
*/
public async deployModel(connection: azdata.connection.ConnectionProfile, modelPath: string): Promise<void> {
await this.installDependencies();
await this.executeDeployScripts(connection, modelPath);
}
/**
* Installs dependencies for python client
*/
private async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [
this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true);
}
/**
*
* @param modelPath Loads model parameters
*/
public async loadModelParameters(modelPath: string): Promise<ModelParameters> {
await this.installDependencies();
return await this.executeModelParametersScripts(modelPath);
}
private async executeModelParametersScripts(modelFolderPath: string): Promise<ModelParameters> {
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
let scripts: string[] = [
'import onnx',
'import json',
`onnx_model_path = '${modelFolderPath}'`,
`onnx_model = onnx.load_model(onnx_model_path)`,
`type_map = {
onnx.TensorProto.DataType.FLOAT: 'real',
onnx.TensorProto.DataType.UINT8: 'tinyint',
onnx.TensorProto.DataType.INT16: 'smallint',
onnx.TensorProto.DataType.INT32: 'int',
onnx.TensorProto.DataType.INT64: 'bigint',
onnx.TensorProto.DataType.STRING: 'varchar(MAX)',
onnx.TensorProto.DataType.DOUBLE: 'float'}`,
`parameters = {
"inputs": [],
"outputs": []
}`,
`def addParameters(list, paramType):
for id, p in enumerate(list):
p_type = ''
if p.type.tensor_type.elem_type in type_map:
p_type = type_map[p.type.tensor_type.elem_type]
parameters[paramType].append({
'name': p.name,
'type': p_type
})`,
'addParameters(onnx_model.graph.input, "inputs")',
'addParameters(onnx_model.graph.output, "outputs")',
'print(json.dumps(parameters))'
];
let pythonExecutable = this._config.pythonExecutable;
let output = await this._processService.execScripts(pythonExecutable, scripts, [], undefined);
let parametersJson = JSON.parse(output);
return Object.assign({}, parametersJson);
}
private async executeDeployScripts(connection: azdata.connection.ConnectionProfile, modelFolderPath: string): Promise<void> {
let home = utils.makeLinuxPath(os.homedir());
modelFolderPath = utils.makeLinuxPath(modelFolderPath);
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
if (connection) {
let server = connection.serverName;
const experimentId = `ads_ml_experiment_${UUID.generateUuid()}`;
const credential = connection.userName ? `${connection.userName}:${credentials[azdata.ConnectionOptionSpecialType.password]}@` : '';
let scripts: string[] = [
'import mlflow.onnx',
`tracking_uri = "file://${home}/mlruns"`,
'print(tracking_uri)',
'import onnx',
'from mlflow.tracking.client import MlflowClient',
`onx = onnx.load("${modelFolderPath}")`,
`mlflow.set_tracking_uri(tracking_uri)`,
'client = MlflowClient()',
`exp_name = "${experimentId}"`,
`db_uri_artifact = "mssql+pyodbc://${credential}${server}/MlFlowDB?driver=ODBC+Driver+17+for+SQL+Server&"`,
'client.create_experiment(exp_name, artifact_location=db_uri_artifact)',
'mlflow.set_experiment(exp_name)',
'mlflow.onnx.log_model(onx, "pipeline_vectorize")'
];
let pythonExecutable = this._config.pythonExecutable;
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
}
}
}

View File

@@ -0,0 +1,143 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
export const subscriptionId: msRest.OperationURLParameter = {
parameterPath: 'subscriptionId',
mapper: {
required: true,
serializedName: 'subscriptionId',
type: {
name: 'String'
}
}
};
export const resourceGroupName: msRest.OperationURLParameter = {
parameterPath: 'resourceGroupName',
mapper: {
required: true,
serializedName: 'resourceGroupName',
type: {
name: 'String'
}
}
};
export const workspaceName: msRest.OperationURLParameter = {
parameterPath: 'workspaceName',
mapper: {
required: true,
serializedName: 'workspaceName',
type: {
name: 'String'
}
}
};
export const workspace: msRest.OperationURLParameter = {
parameterPath: 'workspace',
mapper: {
required: true,
serializedName: 'workspace',
type: {
name: 'String'
}
}
};
export const resourceGroup: msRest.OperationURLParameter = {
parameterPath: 'resourceGroup',
mapper: {
required: true,
serializedName: 'resourceGroup',
type: {
name: 'String'
}
}
};
export const id: msRest.OperationURLParameter = {
parameterPath: 'id',
mapper: {
required: true,
serializedName: 'id',
type: {
name: 'String'
}
}
};
export const acceptLanguage: msRest.OperationParameter = {
parameterPath: 'acceptLanguage',
mapper: {
serializedName: 'accept-language',
defaultValue: 'en-US',
type: {
name: 'String'
}
}
};
export const apiVersion: msRest.OperationQueryParameter = {
parameterPath: 'apiVersion',
mapper: {
required: true,
serializedName: 'api-version',
type: {
name: 'String'
}
}
};
export const origin: msRest.OperationURLParameter = {
parameterPath: 'origin',
mapper: {
required: true,
serializedName: 'origin',
type: {
name: 'String'
}
}
};
export const container: msRest.OperationURLParameter = {
parameterPath: 'container',
mapper: {
required: true,
serializedName: 'container',
type: {
name: 'String'
}
}
};
export const projectName0: msRest.OperationQueryParameter = {
parameterPath: [
'options',
'projectName'
],
mapper: {
serializedName: 'projectName',
type: {
name: 'String'
}
}
};
export const path1: msRest.OperationQueryParameter = {
parameterPath: [
'options',
'path'
],
mapper: {
serializedName: 'path',
type: {
name: 'String'
}
}
};
export const accountName: msRest.OperationQueryParameter = {
parameterPath: [
'options',
'accountName'
],
mapper: {
serializedName: 'accountName',
type: {
name: 'String'
}
}
};

View File

@@ -0,0 +1,195 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as utils from '../common/utils';
import { DatabaseTable } from '../prediction/interfaces';
import { ImportedModel } from './interfaces';
export function getDatabaseConfigureQuery(configTable: DatabaseTable): string {
return `
IF NOT EXISTS (
SELECT name
FROM sys.databases
WHERE name = N'${utils.doubleEscapeSingleQuotes(configTable.databaseName)}'
)
CREATE DATABASE [${utils.doubleEscapeSingleBrackets(configTable.databaseName)}]
`;
}
export function getDeployedModelsQuery(table: DatabaseTable): string {
return `
${selectQuery}
FROM ${utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '')}
WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
ORDER BY model_id
`;
}
/**
* Verifies config table has the expected schema
* @param databaseName
* @param tableName
*/
export function getConfigTableVerificationQuery(table: DatabaseTable): string {
let tableName = table.tableName;
let schemaName = table.schema;
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
return `
IF NOT EXISTS (
SELECT name
FROM sys.databases
WHERE name = N'${utils.doubleEscapeSingleQuotes(table.databaseName)}'
)
BEGIN
SELECT 1
END
ELSE
BEGIN
USE [${utils.doubleEscapeSingleBrackets(table.databaseName)}]
IF EXISTS
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
IF EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_name')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_id')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_description')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_framework_version')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_version')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_creation_time')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='model_deployment_time')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='deployed_by')
AND EXISTS (SELECT * FROM syscolumns WHERE ID=OBJECT_ID('${twoPartTableName}') AND NAME='run_id')
BEGIN
SELECT 1
END
ELSE
BEGIN
SELECT 0
END
END
ELSE
SELECT 1
END
`;
}
/**
* Creates the import table if doesn't exist
*/
export function getConfigureTableQuery(table: DatabaseTable): string {
let tableName = table.tableName;
let schemaName = table.schema;
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
return `
IF NOT EXISTS
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = '${utils.doubleEscapeSingleQuotes(tableName)}'
AND s.name = '${utils.doubleEscapeSingleQuotes(schemaName)}'
)
BEGIN
CREATE TABLE ${twoPartTableName}(
[model_id] [int] IDENTITY(1,1) NOT NULL,
[model_name] [varchar](256) NOT NULL,
[model_framework] [varchar](256) NULL,
[model_framework_version] [varchar](256) NULL,
[model] [varbinary](max) NOT NULL,
[model_version] [varchar](256) NULL,
[model_creation_time] [datetime2] NULL,
[model_deployment_time] [datetime2] NULL,
[deployed_by] [int] NULL,
[model_description] [varchar](256) NULL,
[run_id] [varchar](256) NULL,
CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_models_pk] PRIMARY KEY CLUSTERED
(
[model_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE ${twoPartTableName} ADD CONSTRAINT [${utils.doubleEscapeSingleBrackets(tableName)}_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
END
`;
}
export function getInsertModelQuery(model: ImportedModel, table: DatabaseTable): string {
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(table.tableName || '', table.schema || '');
const threePartTableName = utils.getRegisteredModelsThreePartsName(table.databaseName || '', table.tableName || '', table.schema || '');
let updateScript = `
INSERT INTO ${twoPartTableName}
(model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id)
VALUES (
'${utils.doubleEscapeSingleQuotes(model.modelName || '')}',
${utils.doubleEscapeSingleQuotes(model.content || '')},
'${utils.doubleEscapeSingleQuotes(model.version || '')}',
'${utils.doubleEscapeSingleQuotes(model.description || '')}',
'${utils.doubleEscapeSingleQuotes(model.created || '')}',
'${utils.doubleEscapeSingleQuotes(model.framework || '')}',
'${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
'${utils.doubleEscapeSingleQuotes(model.runId || '')}')
`;
return `
${updateScript}
${selectQuery}
FROM ${threePartTableName}
WHERE model_id = SCOPE_IDENTITY();
`;
}
export function getModelContentQuery(model: ImportedModel): string {
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
return `
SELECT model
FROM ${threePartTableName}
WHERE model_id = ${model.id};
`;
}
export function getUpdateModelQuery(model: ImportedModel): string {
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || '');
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
let updateScript = `
UPDATE ${twoPartTableName}
SET
model_name = '${utils.doubleEscapeSingleQuotes(model.modelName || '')}',
model_version = '${utils.doubleEscapeSingleQuotes(model.version || '')}',
model_description = '${utils.doubleEscapeSingleQuotes(model.description || '')}',
model_creation_time = '${utils.doubleEscapeSingleQuotes(model.created || '')}',
model_framework = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
model_framework_version = '${utils.doubleEscapeSingleQuotes(model.frameworkVersion || '')}',
run_id = '${utils.doubleEscapeSingleQuotes(model.runId || '')}'
WHERE model_id = ${model.id}`;
return `
${updateScript}
${selectQuery}
FROM ${threePartTableName}
WHERE model_id = ${model.id};
`;
}
export function getDeleteModelQuery(model: ImportedModel): string {
const twoPartTableName = utils.getRegisteredModelsTwoPartsName(model.table.tableName || '', model.table.schema || '');
const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || '');
let updateScript = `
Delete from ${twoPartTableName}
WHERE model_id = ${model.id}`;
return `
${updateScript}
${selectQuery}
FROM ${threePartTableName}
`;
}
export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id';

View File

@@ -0,0 +1,64 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as msRest from '@azure/ms-rest-js';
import { AzureMachineLearningWorkspacesContext } from '@azure/arm-machinelearningservices';
import * as Models from './interfaces';
import * as Mappers from './mappers';
import * as Parameters from './parameters';
/**
* Workspace models client
*/
export class WorkspaceModels {
private readonly client: AzureMachineLearningWorkspacesContext;
constructor(client: AzureMachineLearningWorkspacesContext) {
this.client = client;
}
listModels(resourceGroupName: string, workspaceName: string, options?: msRest.RequestOptionsBase): Promise<Models.ListWorkspaceModelsResult>;
listModels(resourceGroupName: string, workspaceName: string, callback: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): void;
listModels(resourceGroupName: string, workspaceName: string, options: msRest.RequestOptionsBase, callback: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): void;
listModels(resourceGroupName: string, workspaceName: string, options?: msRest.RequestOptionsBase | msRest.ServiceCallback<Models.ListWorkspaceModelsResult>, callback?: msRest.ServiceCallback<Models.ListWorkspaceModelsResult>): Promise<Models.WorkspacesModelsResponse> {
return this.client.sendOperationRequest(
{
resourceGroupName,
workspaceName,
options
},
listModelsOperationSpec,
callback) as Promise<Models.WorkspacesModelsResponse>;
}
}
const serializer = new msRest.Serializer(Mappers);
const listModelsOperationSpec: msRest.OperationSpec = {
httpMethod: 'GET',
path:
'modelmanagement/v1.0/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.MachineLearningServices/workspaces/{workspaceName}/models',
urlParameters: [
Parameters.subscriptionId,
Parameters.resourceGroupName,
Parameters.workspaceName
],
queryParameters: [
Parameters.apiVersion
],
headerParameters: [
Parameters.acceptLanguage
],
responses: {
200: {
bodyMapper: Mappers.ListWorkspaceModelsResult
},
default: {
bodyMapper: Mappers.MachineLearningServiceError
}
},
serializer
};

View File

@@ -0,0 +1,117 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import * as nbExtensionApis from '../typings/notebookServices';
import * as utils from '../common/utils';
export enum ScriptMode {
Install = 'install',
Uninstall = 'uninstall'
}
export abstract class SqlPackageManageProviderBase {
/**
* Base class for all SQL package managers
*/
constructor(protected _apiWrapper: ApiWrapper) {
}
/**
* Returns database names
*/
public async getLocations(): Promise<nbExtensionApis.IPackageLocation[]> {
let connection = await this.getCurrentConnection();
if (connection) {
let databases = await this._apiWrapper.listDatabases(connection.connectionId);
return databases.map(x => {
return { displayName: x, name: x };
});
}
return [];
}
protected async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
/**
* Installs given packages
* @param packages Packages to install
* @param useMinVersion minimum version
*/
public async installPackages(packages: nbExtensionApis.IPackageDetails[], useMinVersion: boolean, databaseName: string): Promise<void> {
if (packages) {
await Promise.all(packages.map(x => this.installPackage(x, useMinVersion, databaseName)));
}
//TODO: use useMinVersion
console.log(useMinVersion);
}
private async installPackage(packageDetail: nbExtensionApis.IPackageDetails, useMinVersion: boolean, databaseName: string): Promise<void> {
if (useMinVersion) {
let packageOverview = await this.getPackageOverview(packageDetail.name);
if (packageOverview && packageOverview.versions) {
let minVersion = packageOverview.versions[packageOverview.versions.length - 1];
packageDetail.version = minVersion;
}
}
await this.executeScripts(ScriptMode.Install, packageDetail, databaseName);
}
/**
* Uninstalls given packages
* @param packages Packages to uninstall
*/
public async uninstallPackages(packages: nbExtensionApis.IPackageDetails[], databaseName: string): Promise<void> {
if (packages) {
await Promise.all(packages.map(x => this.executeScripts(ScriptMode.Uninstall, x, databaseName)));
}
}
/**
* Returns package overview for given name
* @param packageName Package Name
*/
public async getPackageOverview(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
let packageOverview = await this.fetchPackage(packageName);
if (packageOverview && packageOverview.versions) {
packageOverview.versions = utils.sortPackageVersions(packageOverview.versions, false);
}
return packageOverview;
}
/**
* Returns list of packages
*/
public async listPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
let packages = await this.fetchPackages(databaseName);
if (packages) {
packages = packages.sort((a, b) => this.comparePackages(a, b));
} else {
packages = [];
}
return packages;
}
private comparePackages(p1: nbExtensionApis.IPackageDetails, p2: nbExtensionApis.IPackageDetails): number {
if (p1 && p2) {
let compare = p1.name.localeCompare(p2.name);
if (compare === 0) {
compare = utils.comparePackageVersions(p1.version, p2.version);
}
return compare;
}
return p1 ? 1 : -1;
}
protected abstract fetchPackage(packageName: string): Promise<nbExtensionApis.IPackageOverview>;
protected abstract fetchPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]>;
protected abstract executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void>;
}

View File

@@ -0,0 +1,95 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import { QueryRunner } from '../common/queryRunner';
import * as constants from '../common/constants';
import { ApiWrapper } from '../common/apiWrapper';
import * as utils from '../common/utils';
import * as nbExtensionApis from '../typings/notebookServices';
export class PackageManagementService {
/**
* Creates a new instance of ServerConfigManager
*/
constructor(
private _apiWrapper: ApiWrapper,
private _queryRunner: QueryRunner,
) {
}
/**
* Opens server config documents
*/
public async openDocuments(): Promise<boolean> {
return await this._apiWrapper.openExternal(vscode.Uri.parse(constants.mlsDocuments));
}
/**
* Returns true if mls is installed in the give SQL server instance
*/
public async isMachineLearningServiceEnabled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
return this._queryRunner.isMachineLearningServiceEnabled(connection);
}
/**
* Returns true if R installed in the give SQL server instance
*/
public async isRInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
return this._queryRunner.isRInstalled(connection);
}
/**
* Returns true if python installed in the give SQL server instance
*/
public async isPythonInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
return this._queryRunner.isPythonInstalled(connection);
}
/**
* Updates external script config
* @param connection SQL Connection
* @param enable if true external script will be enabled
*/
public async enableExternalScriptConfig(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
let current = await this._queryRunner.isMachineLearningServiceEnabled(connection);
if (current) {
return current;
}
let confirmed = await utils.promptConfirm(constants.confirmEnableExternalScripts, this._apiWrapper);
if (confirmed) {
await this._queryRunner.updateExternalScriptConfig(connection, true);
current = await this._queryRunner.isMachineLearningServiceEnabled(connection);
if (current) {
this._apiWrapper.showInfoMessage(constants.mlsEnabledMessage);
} else {
this._apiWrapper.showErrorMessage(constants.mlsConfigUpdateFailed);
}
} else {
this._apiWrapper.showErrorMessage(constants.externalScriptsIsRequiredError);
}
return current;
}
/**
* Returns python packages installed in SQL server instance
* @param connection SQL Connection
*/
public async getPythonPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
return this._queryRunner.getPythonPackages(connection, databaseName);
}
/**
* Returns python packages installed in SQL server instance
* @param connection SQL Connection
*/
public async getRPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
return this._queryRunner.getRPackages(connection, databaseName);
}
}

View File

@@ -0,0 +1,219 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as nbExtensionApis from '../typings/notebookServices';
import { SqlPythonPackageManageProvider } from './sqlPythonPackageManageProvider';
import * as utils from '../common/utils';
import * as constants from '../common/constants';
import { ApiWrapper } from '../common/apiWrapper';
import { ProcessService } from '../common/processService';
import { Config } from '../configurations/config';
import { isNullOrUndefined } from 'util';
import { SqlRPackageManageProvider } from './sqlRPackageManageProvider';
import { HttpClient } from '../common/httpClient';
import { PackageConfigModel } from '../configurations/packageConfigModel';
import { PackageManagementService } from './packageManagementService';
export class PackageManager {
private _sqlPythonPackagePackageManager: SqlPythonPackageManageProvider;
private _sqlRPackageManager: SqlRPackageManageProvider;
public dependenciesInstalled: boolean = false;
/**
* Creates a new instance of PackageManager
*/
constructor(
private _outputChannel: vscode.OutputChannel,
private _rootFolder: string,
private _apiWrapper: ApiWrapper,
private _service: PackageManagementService,
private _processService: ProcessService,
private _config: Config,
private _httpClient: HttpClient) {
this._sqlPythonPackagePackageManager = new SqlPythonPackageManageProvider(this._outputChannel, this._apiWrapper, this._service, this._processService, this._config, this._httpClient);
this._sqlRPackageManager = new SqlRPackageManageProvider(this._outputChannel, this._apiWrapper, this._service, this._processService, this._config, this._httpClient);
}
/**
* Initializes the instance and resister SQL package manager with manage package dialog
*/
public init(): void {
}
private get pythonExecutable(): string {
return this._config.pythonExecutable;
}
private get _rExecutable(): string {
return this._config.rExecutable;
}
/**
* Returns packageManageProviders
*/
public get packageManageProviders(): nbExtensionApis.IPackageManageProvider[] {
return [
this._sqlPythonPackagePackageManager,
this._sqlRPackageManager
];
}
/**
* Executes manage package command for SQL server packages.
*/
public async managePackages(): Promise<void> {
try {
await this.enableExternalScript();
// Only execute the command if there's a valid connection with ml configuration enabled
//
let connection = await this.getCurrentConnection();
let isPythonInstalled = await this._service.isPythonInstalled(connection);
let isRInstalled = await this._service.isRInstalled(connection);
let defaultProvider: SqlRPackageManageProvider | SqlPythonPackageManageProvider | undefined;
if (connection && isPythonInstalled && this._sqlPythonPackagePackageManager.canUseProvider) {
defaultProvider = this._sqlPythonPackagePackageManager;
} else if (connection && isRInstalled && this._sqlRPackageManager.canUseProvider) {
defaultProvider = this._sqlRPackageManager;
}
if (connection && defaultProvider) {
await this.enableExternalScript();
// Install dependencies
//
if (!this.dependenciesInstalled) {
await this.installDependencies();
this.dependenciesInstalled = true;
}
// Execute the command
//
this._apiWrapper.executeCommand(constants.managePackagesCommand, {
defaultLocation: defaultProvider.packageTarget.location,
defaultProviderId: defaultProvider.providerId
});
} else {
this._apiWrapper.showInfoMessage(constants.managePackageCommandError);
}
} catch (err) {
this._apiWrapper.showErrorMessage(err);
}
}
public async enableExternalScript(): Promise<void> {
let connection = await this.getCurrentConnection();
if (!await this._service.enableExternalScriptConfig(connection)) {
throw Error(constants.externalScriptsIsRequiredError);
}
}
/**
* Installs dependencies for the extension
*/
public async installDependencies(): Promise<void> {
await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
this.installRequiredRPackages()], true);
}
private async installRequiredRPackages(): Promise<void> {
if (!this._config.rEnabled) {
return;
}
if (!this._rExecutable) {
throw new Error(constants.rConfigError);
}
await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder));
await Promise.all(this._config.requiredSqlRPackages.map(x => this.installRPackage(x)));
}
/**
* Installs required python packages
*/
public async installRequiredPythonPackages(requiredPackages: PackageConfigModel[]): Promise<void> {
if (!this._config.pythonEnabled) {
return;
}
if (!this.pythonExecutable) {
throw new Error(constants.pythonConfigError);
}
if (!requiredPackages || requiredPackages.length === 0) {
return;
}
let installedPackages = await this.getInstalledPipPackages();
let fileContent = '';
requiredPackages.forEach(packageDetails => {
let hasVersion = ('version' in packageDetails) && !isNullOrUndefined(packageDetails['version']) && packageDetails['version'].length > 0;
if (!installedPackages.find(x => x.name === packageDetails['name']
&& (!hasVersion || utils.comparePackageVersions(packageDetails['version'] || '', x.version) <= 0))) {
let packageNameDetail = hasVersion ? `${packageDetails.name}==${packageDetails.version}` : `${packageDetails.name}`;
fileContent = `${fileContent}${packageNameDetail}\n`;
}
});
if (fileContent) {
let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages(fileContent), this._apiWrapper);
if (confirmed) {
this._outputChannel.appendLine(constants.installDependenciesPackages);
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
return await this.installPipPackage(tempFilePath);
});
this._outputChannel.appendLine(result);
} else {
throw Error(constants.requiredPackagesNotInstalled);
}
} else {
this._outputChannel.appendLine(constants.installDependenciesPackagesAlreadyInstalled);
}
}
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
try {
let cmd = `"${this.pythonExecutable}" -m pip list --format=json`;
let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined);
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
if (packagesInfo && packagesInfo.indexOf(']') > 0) {
packagesResult = <nbExtensionApis.IPackageDetails[]>JSON.parse(packagesInfo.substr(0, packagesInfo.indexOf(']') + 1));
}
return packagesResult;
}
catch (err) {
this._outputChannel.appendLine(constants.installDependenciesGetPackagesError(err ? err.message : ''));
return [];
}
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
private async installPipPackage(requirementFilePath: string): Promise<string> {
let cmd = `"${this.pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
}
private async installRPackage(model: PackageConfigModel): Promise<string> {
let output = '';
let cmd = '';
if (model.downloadUrl) {
const packageFile = utils.getPackageFilePath(this._rootFolder, model.fileName || model.name);
const packageExist = await utils.exists(packageFile);
if (!packageExist) {
await this._httpClient.download(model.downloadUrl, packageFile, this._outputChannel);
}
cmd = `"${this._rExecutable}" CMD INSTALL ${packageFile}`;
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
} else if (model.repository) {
cmd = `"${this._rExecutable}" -e "install.packages('${model.name}', repos='${model.repository}')"`;
output = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
}
return output;
}
}

View File

@@ -0,0 +1,136 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as nbExtensionApis from '../typings/notebookServices';
import { ApiWrapper } from '../common/apiWrapper';
import { ProcessService } from '../common/processService';
import { Config } from '../configurations/config';
import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProviderBase';
import { HttpClient } from '../common/httpClient';
import * as utils from '../common/utils';
import { PackageManagementService } from './packageManagementService';
/**
* Manage Package Provider for python packages inside SQL server databases
*/
export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase implements nbExtensionApis.IPackageManageProvider {
public static ProviderId = 'sql_Python';
/**
* Creates new a instance
*/
constructor(
private _outputChannel: vscode.OutputChannel,
apiWrapper: ApiWrapper,
private _service: PackageManagementService,
private _processService: ProcessService,
private _config: Config,
private _httpClient: HttpClient) {
super(apiWrapper);
}
/**
* Returns provider Id
*/
public get providerId(): string {
return SqlPythonPackageManageProvider.ProviderId;
}
/**
* Returns package target
*/
public get packageTarget(): nbExtensionApis.IPackageTarget {
return { location: 'SQL', packageType: 'Python' };
}
/**
* Returns list of packages
*/
protected async fetchPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
return await this._service.getPythonPackages(await this.getCurrentConnection(), databaseName);
}
/**
* Execute a script to install or uninstall a python package inside current SQL Server connection
* @param packageDetails Packages to install or uninstall
* @param scriptMode can be 'install' or 'uninstall'
*/
protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void> {
let connection = await this.getCurrentConnection();
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
if (connection) {
let port = '1433';
let server = connection.serverName;
let database = databaseName ? `, database="${databaseName}"` : '';
let index = connection.serverName.indexOf(',');
if (index > 0) {
port = connection.serverName.substring(index + 1);
server = connection.serverName.substring(0, index);
}
let pythonConnectionParts = `server="${server}", port=${port}, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"${database})`;
let pythonCommandScript = scriptMode === ScriptMode.Install ?
`pkgmanager.install(package="${packageDetails.name}", version="${packageDetails.version}")` :
`pkgmanager.uninstall(package_name="${packageDetails.name}")`;
let scripts: string[] = [
'import sqlmlutils',
`connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 17 for SQL Server", ${pythonConnectionParts}`,
'pkgmanager = sqlmlutils.SQLPackageManager(connection)',
pythonCommandScript
];
let pythonExecutable = this._config.pythonExecutable;
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
}
}
/**
* Returns true if the provider can be used
*/
async canUseProvider(): Promise<boolean> {
if (!this._config.pythonEnabled) {
return false;
}
let connection = await this.getCurrentConnection();
if (connection && await this._service.isPythonInstalled(connection)) {
return true;
}
return false;
}
private getPackageLink(packageName: string): string {
return `https://pypi.org/pypi/${packageName}/json`;
}
protected async fetchPackage(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
let body = await this._httpClient.fetch(this.getPackageLink(packageName));
let packagesJson = JSON.parse(body);
let versionNums: string[] = [];
let packageSummary = '';
if (packagesJson) {
if (packagesJson.releases) {
let versionKeys = Object.keys(packagesJson.releases);
versionKeys = versionKeys.filter(versionKey => {
let releaseInfo = packagesJson.releases[versionKey];
return Array.isArray(releaseInfo) && releaseInfo.length > 0;
});
versionNums = utils.sortPackageVersions(versionKeys, false);
}
if (packagesJson.info && packagesJson.info.summary) {
packageSummary = packagesJson.info.summary;
}
}
return {
name: packageName,
versions: versionNums,
summary: packageSummary
};
}
}

View File

@@ -0,0 +1,123 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as nbExtensionApis from '../typings/notebookServices';
import { ApiWrapper } from '../common/apiWrapper';
import { ProcessService } from '../common/processService';
import { Config } from '../configurations/config';
import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProviderBase';
import { HttpClient } from '../common/httpClient';
import * as constants from '../common/constants';
import { PackageManagementService } from './packageManagementService';
/**
* Manage Package Provider for r packages inside SQL server databases
*/
export class SqlRPackageManageProvider extends SqlPackageManageProviderBase implements nbExtensionApis.IPackageManageProvider {
public static ProviderId = 'sql_R';
/**
* Creates new a instance
*/
constructor(
private _outputChannel: vscode.OutputChannel,
apiWrapper: ApiWrapper,
private _service: PackageManagementService,
private _processService: ProcessService,
private _config: Config,
private _httpClient: HttpClient) {
super(apiWrapper);
}
/**
* Returns provider Id
*/
public get providerId(): string {
return SqlRPackageManageProvider.ProviderId;
}
/**
* Returns package target
*/
public get packageTarget(): nbExtensionApis.IPackageTarget {
return { location: 'SQL', packageType: 'R' };
}
/**
* Returns list of packages
*/
protected async fetchPackages(databaseName: string): Promise<nbExtensionApis.IPackageDetails[]> {
return await this._service.getRPackages(await this.getCurrentConnection(), databaseName);
}
/**
* Execute a script to install or uninstall a r package inside current SQL Server connection
* @param packageDetails Packages to install or uninstall
* @param scriptMode can be 'install' or 'uninstall'
*/
protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void> {
let connection = await this.getCurrentConnection();
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
if (connection) {
let database = databaseName ? `, database="${databaseName}"` : '';
let connectionParts = `server="${connection.serverName}", uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"${database}`;
let rCommandScript = scriptMode === ScriptMode.Install ? 'sql_install.packages' : 'sql_remove.packages';
let scripts: string[] = [
'formals(quit)$save <- formals(q)$save <- "no"',
'library(sqlmlutils)',
`connection <- connectionInfo(${connectionParts})`,
`r = getOption("repos")`,
`r["CRAN"] = "${this._config.rPackagesRepository}"`,
`options(repos = r)`,
`pkgs <- c("${packageDetails.name}")`,
`${rCommandScript}(connectionString = connection, pkgs, scope = "PUBLIC")`,
'q()'
];
let rExecutable = this._config.rExecutable;
await this._processService.execScripts(`${rExecutable}`, scripts, ['--vanilla'], this._outputChannel);
}
}
/**
* Returns true if the provider can be used
*/
async canUseProvider(): Promise<boolean> {
if (!this._config.rEnabled) {
return false;
}
let connection = await this.getCurrentConnection();
if (connection && await this._service.isRInstalled(connection)) {
return true;
}
return false;
}
private getPackageLink(packageName: string): string {
return `${this._config.rPackagesRepository}/web/packages/${packageName}`;
}
/**
* Returns package overview for given name
* @param packageName Package Name
*/
protected async fetchPackage(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
let packagePreview: nbExtensionApis.IPackageOverview = {
name: packageName,
versions: [constants.latestVersion],
summary: ''
};
await this._httpClient.fetch(this.getPackageLink(packageName));
return packagePreview;
}
}

View File

@@ -0,0 +1,27 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
export interface TableColumn {
columnName: string;
dataType?: string;
}
export interface PredictColumn extends TableColumn {
paramName?: string;
}
export interface DatabaseTable {
databaseName: string | undefined;
tableName: string | undefined;
schema: string | undefined
}
export interface PredictInputParameters extends DatabaseTable {
inputColumns: PredictColumn[] | undefined
}
export interface PredictParameters extends PredictInputParameters {
outputColumns: PredictColumn[] | undefined
}

View File

@@ -0,0 +1,217 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import * as utils from '../common/utils';
import { ImportedModel } from '../modelManagement/interfaces';
import { PredictParameters, PredictColumn, DatabaseTable, TableColumn } from '../prediction/interfaces';
/**
* Service to make prediction
*/
export class PredictService {
/**
* Creates new instance
*/
constructor(
private _apiWrapper: ApiWrapper,
private _queryRunner: QueryRunner) {
}
/**
* Returns the list of databases
*/
public async getDatabaseList(): Promise<string[]> {
let connection = await this.getCurrentConnection();
if (connection) {
return await this._apiWrapper.listDatabases(connection.connectionId);
}
return [];
}
/**
* Generates prediction script given model info and predict parameters
* @param predictParams predict parameters
* @param registeredModel model parameters
*/
public async generatePredictScript(
predictParams: PredictParameters,
registeredModel: ImportedModel | undefined,
filePath: string | undefined
): Promise<string> {
let connection = await this.getCurrentConnection();
let query = '';
if (registeredModel && registeredModel.id) {
query = this.getPredictScriptWithModelId(
registeredModel.id,
predictParams.inputColumns || [],
predictParams.outputColumns || [],
predictParams,
registeredModel.table);
} else if (filePath) {
let modelBytes = await utils.readFileInHex(filePath || '');
query = this.getPredictScriptWithModelBytes(modelBytes, predictParams.inputColumns || [],
predictParams.outputColumns || [],
predictParams);
}
let document = await this._apiWrapper.openTextDocument({
language: 'sql',
content: query
});
await this._apiWrapper.showTextDocument(document.uri);
await this._apiWrapper.connect(document.uri.toString(), connection.connectionId);
this._apiWrapper.runQuery(document.uri.toString(), undefined, false);
return query;
}
/**
* Returns list of tables given database name
* @param databaseName database name
*/
public async getTableList(databaseName: string): Promise<DatabaseTable[]> {
let connection = await this.getCurrentConnection();
let list: DatabaseTable[] = [];
if (connection) {
let query = utils.getScriptWithDBChange(connection.databaseName, databaseName, this.getTablesScript(databaseName));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push({
databaseName: databaseName,
tableName: row[0].displayValue,
schema: row[1].displayValue
});
});
}
}
return list;
}
/**
*Returns list of column names of a database
* @param databaseTable table info
*/
public async getTableColumnsList(databaseTable: DatabaseTable): Promise<TableColumn[]> {
let connection = await this.getCurrentConnection();
let list: TableColumn[] = [];
if (connection && databaseTable.databaseName) {
const query = utils.getScriptWithDBChange(connection.databaseName, databaseTable.databaseName, this.getTableColumnsScript(databaseTable));
let result = await this._queryRunner.safeRunQuery(connection, query);
if (result && result.rows && result.rows.length > 0) {
result.rows.forEach(row => {
list.push({
columnName: row[0].displayValue,
dataType: row[1].displayValue
});
});
}
}
return list;
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
private getTableColumnsScript(databaseTable: DatabaseTable): string {
return `
SELECT COLUMN_NAME,DATA_TYPE
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME='${utils.doubleEscapeSingleQuotes(databaseTable.tableName)}'
AND TABLE_SCHEMA='${utils.doubleEscapeSingleQuotes(databaseTable.schema)}'
AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseTable.databaseName)}'
`;
}
private getTablesScript(databaseName: string): string {
return `
SELECT TABLE_NAME,TABLE_SCHEMA
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG='${utils.doubleEscapeSingleQuotes(databaseName)}'
`;
}
private getPredictScriptWithModelId(
modelId: number,
columns: PredictColumn[],
outputColumns: PredictColumn[],
sourceTable: DatabaseTable,
importTable: DatabaseTable): string {
const threePartTableName = utils.getRegisteredModelsThreePartsName(importTable.databaseName || '', importTable.tableName || '', importTable.schema || '');
return `
DECLARE @model VARBINARY(max) = (
SELECT model
FROM ${threePartTableName}
WHERE model_id = ${modelId}
);
WITH predict_input
AS (
SELECT TOP 1000
${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi
)
SELECT
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = @model, DATA = predict_input)
WITH (
${this.getOutputParameters(outputColumns)}
) AS p
`;
}
private getPredictScriptWithModelBytes(
modelBytes: string,
columns: PredictColumn[],
outputColumns: PredictColumn[],
databaseNameTable: DatabaseTable): string {
return `
WITH predict_input
AS (
SELECT TOP 1000
${this.getInputColumnNames(columns, 'pi')}
FROM [${utils.doubleEscapeSingleBrackets(databaseNameTable.databaseName)}].[${databaseNameTable.schema}].[${utils.doubleEscapeSingleBrackets(databaseNameTable.tableName)}] as pi
)
SELECT
${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getOutputColumnNames(outputColumns, 'p')}
FROM PREDICT(MODEL = ${modelBytes}, DATA = predict_input)
WITH (
${this.getOutputParameters(outputColumns)}
) AS p
`;
}
private getInputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return this.getColumnName(tableName, c.paramName || '', c.columnName);
}).join(',\n');
}
private getOutputColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return this.getColumnName(tableName, c.columnName, c.paramName || '');
}).join(',\n');
}
private getColumnName(tableName: string, columnName: string, displayName: string) {
return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`;
}
private getPredictColumnNames(columns: PredictColumn[], tableName: string) {
return columns.map(c => {
return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`;
}).join(',\n');
}
private getOutputParameters(columns: PredictColumn[]) {
return columns.map(c => {
return `${c.paramName} ${c.dataType}`;
}).join(',\n');
}
}

View File

@@ -0,0 +1,67 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import { ProcessService } from '../../common/processService';
import * as utils from '../../common/utils';
import should = require('should');
interface TestContext {
outputChannel: vscode.OutputChannel;
}
function createContext(): TestContext {
return {
outputChannel: {
name: '',
append: () => { },
appendLine: () => { },
clear: () => { },
show: () => { },
hide: () => { },
dispose: () => { }
}
};
}
function execFolderListCommand(context: TestContext, service : ProcessService): Promise<string> {
if (utils.isWindows()) {
return service.execScripts('cmd', ['dir', '.'], [], context.outputChannel);
} else {
return service.execScripts('/bin/sh', ['-c', 'ls'], [], context.outputChannel);
}
}
function execFolderListBufferedCommand(context: TestContext, service : ProcessService): Promise<string> {
if (utils.isWindows()) {
return service.executeBufferedCommand('dir', context.outputChannel);
} else {
return service.executeBufferedCommand('ls', context.outputChannel);
}
}
describe('Process Service', () => {
it('Executing a valid script should return successfully', async function (): Promise<void> {
const context = createContext();
let service = new ProcessService();
await should(execFolderListCommand(context, service)).resolved();
});
it('execFolderListCommand should reject if command time out @UNSTABLE@', async function (): Promise<void> {
const context = createContext();
let service = new ProcessService();
service.timeout = 10;
await should(execFolderListCommand(context, service)).rejected();
});
it('executeBufferedCommand should resolve give valid script', async function (): Promise<void> {
const context = createContext();
let service = new ProcessService();
service.timeout = 2000;
await should(execFolderListBufferedCommand(context, service)).resolved();
});
});

View File

@@ -0,0 +1,48 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as path from 'path';
const testRunner = require('vscodetestcover');
const suite = 'machine learning Extension Tests';
const mochaOptions: any = {
ui: 'bdd',
useColors: true,
timeout: 10000
};
// set relevant mocha options from the environment
if (process.env.ADS_TEST_GREP) {
mochaOptions.grep = process.env.ADS_TEST_GREP;
console.log(`setting options.grep to: ${mochaOptions.grep}`);
}
if (process.env.ADS_TEST_INVERT_GREP) {
mochaOptions.invert = parseInt(process.env.ADS_TEST_INVERT_GREP);
console.log(`setting options.invert to: ${mochaOptions.invert}`);
}
if (process.env.ADS_TEST_TIMEOUT) {
mochaOptions.timeout = parseInt(process.env.ADS_TEST_TIMEOUT);
console.log(`setting options.timeout to: ${mochaOptions.timeout}`);
}
if (process.env.ADS_TEST_RETRIES) {
mochaOptions.retries = parseInt(process.env.ADS_TEST_RETRIES);
console.log(`setting options.retries to: ${mochaOptions.retries}`);
}
if (process.env.BUILD_ARTIFACTSTAGINGDIRECTORY) {
mochaOptions.reporter = 'mocha-multi-reporters';
mochaOptions.reporterOptions = {
reporterEnabled: 'spec, mocha-junit-reporter',
mochaJunitReporterReporterOptions: {
testsuitesTitle: `${suite} ${process.platform}`,
mochaFile: path.join(process.env.BUILD_ARTIFACTSTAGINGDIRECTORY, `test-results/${process.platform}-${suite.toLowerCase().replace(/[^\w]/g, '-')}-results.xml`)
}
};
}
testRunner.configure(mochaOptions, { coverConfig: '../../coverConfig.json' });
export = testRunner;

View File

@@ -0,0 +1,150 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import * as path from 'path';
import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import { ProcessService } from '../common/processService';
import MainController from '../controllers/mainController';
import { PackageManager } from '../packageManagement/packageManager';
import * as nbExtensionApis from '../typings/notebookServices';
interface TestContext {
notebookExtension: vscode.Extension<any>;
jupyterInstallation: nbExtensionApis.IJupyterServerInstallation;
jupyterController: nbExtensionApis.IJupyterController;
nbExtensionApis: nbExtensionApis.IExtensionApi;
apiWrapper: TypeMoq.IMock<ApiWrapper>;
queryRunner: TypeMoq.IMock<QueryRunner>;
processService: TypeMoq.IMock<ProcessService>;
context: vscode.ExtensionContext;
outputChannel: vscode.OutputChannel;
extension: vscode.Extension<any>;
packageManager: TypeMoq.IMock<PackageManager>;
workspaceConfig: vscode.WorkspaceConfiguration;
}
function createContext(): TestContext {
let packages = new Map<string, nbExtensionApis.IPackageManageProvider>();
let jupyterInstallation: nbExtensionApis.IJupyterServerInstallation = {
installCondaPackages: () => { return Promise.resolve(); },
getInstalledPipPackages: () => { return Promise.resolve([]); },
installPipPackages: () => { return Promise.resolve(); },
uninstallPipPackages: () => { return Promise.resolve(); },
uninstallCondaPackages: () => { return Promise.resolve(); },
executeBufferedCommand: () => { return Promise.resolve(''); },
executeStreamedCommand: () => { return Promise.resolve(); },
pythonExecutable: '',
pythonInstallationPath: '',
installPythonPackage: () => { return Promise.resolve(); }
};
let jupyterController = {
jupyterInstallation: jupyterInstallation
};
let extensionPath = path.join(__dirname, '..', '..');
let extensionApi: nbExtensionApis.IExtensionApi = {
getJupyterController: () => { return jupyterController; },
registerPackageManager: (providerId: string, packageManagerProvider: nbExtensionApis.IPackageManageProvider) => {
packages.set(providerId, packageManagerProvider);
},
getPackageManagers: () => { return packages; },
};
return {
jupyterInstallation: jupyterInstallation,
jupyterController: jupyterController,
nbExtensionApis: extensionApi,
notebookExtension: {
id: '',
extensionPath: '',
isActive: true,
packageJSON: '',
extensionKind: vscode.ExtensionKind.UI,
exports: extensionApi,
activate: () => {return Promise.resolve();},
extensionUri: vscode.Uri.parse('')
},
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
processService: TypeMoq.Mock.ofType(ProcessService),
packageManager: TypeMoq.Mock.ofType(PackageManager),
context: {
subscriptions: [],
workspaceState: {
get: () => {return undefined;},
update: () => {return Promise.resolve();}
},
globalState: {
get: () => {return Promise.resolve();},
update: () => {return Promise.resolve();}
},
extensionPath: extensionPath,
asAbsolutePath: () => {return '';},
storagePath: '',
globalStoragePath: '',
logPath: '',
extensionUri: vscode.Uri.parse('')
},
outputChannel: {
name: '',
append: () => { },
appendLine: () => { },
clear: () => { },
show: () => { },
hide: () => { },
dispose: () => { }
},
extension: {
id: '',
extensionPath: '',
isActive: true,
packageJSON: {},
extensionKind: vscode.ExtensionKind.UI,
exports: {},
activate: () => { return Promise.resolve(); },
extensionUri: vscode.Uri.parse('')
},
workspaceConfig: {
get: () => {return 'value';},
has: () => {return true;},
inspect: () => {return undefined;},
update: () => {return Promise.reject();},
}
};
}
function createController(testContext: TestContext): MainController {
let controller = new MainController(testContext.context, testContext.apiWrapper.object, testContext.queryRunner.object, testContext.processService.object, testContext.packageManager.object);
return controller;
}
describe('Main Controller', () => {
it('Should create new instance successfully', async function (): Promise<void> {
let testContext = createContext();
testContext.apiWrapper.setup(x => x.createOutputChannel(TypeMoq.It.isAny())).returns(() => testContext.outputChannel);
should.doesNotThrow(() => createController(testContext));
});
it('initialize Should install dependencies successfully', async function (): Promise<void> {
let testContext = createContext();
testContext.apiWrapper.setup(x => x.getExtension(TypeMoq.It.isAny())).returns(() => testContext.notebookExtension);
testContext.apiWrapper.setup(x => x.getConfiguration(TypeMoq.It.isAny())).returns(() => testContext.workspaceConfig);
testContext.apiWrapper.setup(x => x.createOutputChannel(TypeMoq.It.isAny())).returns(() => testContext.outputChannel);
testContext.apiWrapper.setup(x => x.getExtension(TypeMoq.It.isAny())).returns(() => testContext.extension);
testContext.packageManager.setup(x => x.managePackages()).returns(() => Promise.resolve());
testContext.packageManager.setup(x => x.installDependencies()).returns(() => Promise.resolve());
testContext.apiWrapper.setup(x => x.registerCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny()));
let controller = createController(testContext);
await controller.activate();
should.notEqual(controller.config.requiredSqlPythonPackages.find(x => x.name ==='sqlmlutils'), undefined);
});
});

View File

@@ -0,0 +1,232 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Config } from '../../configurations/config';
import { HttpClient } from '../../common/httpClient';
import { azureResource } from '../../typings/azure-resource';
import * as utils from '../utils';
import { Workspace, WorkspacesListByResourceGroupResponse } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel, AssetsQueryByIdResponse, Asset, GetArtifactContentInformation2Response } from '../../modelManagement/interfaces';
import { AzureMachineLearningWorkspaces, Workspaces } from '@azure/arm-machinelearningservices';
import { WorkspaceModels } from '../../modelManagement/workspacesModels';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
config: TypeMoq.IMock<Config>;
httpClient: TypeMoq.IMock<HttpClient>;
outputChannel: vscode.OutputChannel;
op: azdata.BackgroundOperation;
accounts: azdata.Account[];
subscriptions: azureResource.AzureResourceSubscription[];
groups: azureResource.AzureResourceResourceGroup[];
workspaces: Workspace[];
models: WorkspaceModel[];
client: TypeMoq.IMock<AzureMachineLearningWorkspaces>;
workspacesClient: TypeMoq.IMock<Workspaces>;
modelClient: TypeMoq.IMock<WorkspaceModels>;
}
function createContext(): TestContext {
const context = utils.createContext();
const workspaces = TypeMoq.Mock.ofType(Workspaces);
const credentials = {
signRequest: () => {
return Promise.resolve(undefined!!);
}
};
const client = TypeMoq.Mock.ofInstance(new AzureMachineLearningWorkspaces(credentials, 'subscription'));
client.setup(x => x.apiVersion).returns(() => '20180101');
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
httpClient: TypeMoq.Mock.ofType(HttpClient),
outputChannel: context.outputChannel,
op: context.op,
accounts: [
{
key: {
providerId: '',
accountId: 'a1'
},
displayInfo: {
contextualDisplayName: '',
accountType: '',
displayName: 'a1',
userId: 'a1'
},
properties:
{
tenants: [
{
id: '1',
}
]
}
,
isStale: true
}
],
subscriptions: [
{
name: 's1',
id: 's1'
}
],
groups: [
{
name: 'g1',
id: 'g1'
}
],
workspaces: [{
name: 'w1',
id: 'w1'
}
],
models: [
{
name: 'm1',
id: 'm1',
url: 'aml://asset/test.test'
}
],
client: client,
workspacesClient: workspaces,
modelClient: TypeMoq.Mock.ofInstance(new WorkspaceModels(client.object))
};
}
describe('AzureModelRegistryService', () => {
it('getAccounts should return the list of accounts successfully', async function (): Promise<void> {
let testContext = createContext();
const accounts = testContext.accounts;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
testContext.apiWrapper.setup(x => x.getAllAccounts()).returns(() => Promise.resolve(accounts));
let actual = await service.getAccounts();
should.deepEqual(actual, testContext.accounts);
});
it('getSubscriptions should return the list of subscriptions successfully', async function (): Promise<void> {
let testContext = createContext();
const expected = testContext.subscriptions;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ subscriptions: expected, errors: [] }));
let actual = await service.getSubscriptions(testContext.accounts[0]);
should.deepEqual(actual, expected);
});
it('getGroups should return the list of groups successfully', async function (): Promise<void> {
let testContext = createContext();
const expected = testContext.groups;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ resourceGroups: expected, errors: [] }));
let actual = await service.getGroups(testContext.accounts[0], testContext.subscriptions[0]);
should.deepEqual(actual, expected);
});
it('getWorkspaces should return the list of workspaces successfully', async function (): Promise<void> {
let testContext = createContext();
const response: WorkspacesListByResourceGroupResponse = Object.assign(new Array<Workspace>(...testContext.workspaces), {
_response: undefined!
});
const expected = testContext.workspaces;
testContext.workspacesClient.setup(x => x.listByResourceGroup(TypeMoq.It.isAny())).returns(() => Promise.resolve(response));
testContext.workspacesClient.setup(x => x.listBySubscription()).returns(() => Promise.resolve(response));
testContext.client.setup(x => x.workspaces).returns(() => testContext.workspacesClient.object);
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
service.AzureMachineLearningClient = testContext.client.object;
let actual = await service.getWorkspaces(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0]);
should.deepEqual(actual, expected);
});
it('getModels should return the list of models successfully', async function (): Promise<void> {
let testContext = createContext();
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
const expected = testContext.models;
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
service.AzureMachineLearningClient = testContext.client.object;
service.ModelClient = testContext.modelClient.object;
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
let actual = await service.getModels(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0]);
should.deepEqual(actual, expected);
});
it('downloadModel should download model artifact successfully', async function (): Promise<void> {
let testContext = createContext();
const asset: Asset =
{
id: '1',
name: 'asset',
artifacts: [
{
id: '/1/2/3/4/5/'
}
]
};
const assetResponse: AssetsQueryByIdResponse = Object.assign(asset, {
_response: undefined!
});
const artifactResponse: GetArtifactContentInformation2Response = Object.assign({
contentUri: 'downloadUrl'
}, {
_response: undefined!
});
testContext.config.setup(x => x.amlApiVersion).returns(() => '2018');
testContext.config.setup(x => x.amlModelManagementUrl).returns(() => 'test.url');
testContext.config.setup(x => x.amlExperienceUrl).returns(() => 'test.url');
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('modelmanagement')), TypeMoq.It.isAny())).returns(() => Promise.resolve(assetResponse));
testContext.client.setup(x => x.sendOperationRequest(TypeMoq.It.isAny(),
TypeMoq.It.is(p => p.path !== undefined && p.path.startsWith('artifact')), TypeMoq.It.isAny())).returns(() => Promise.resolve(artifactResponse));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
let service = new AzureModelRegistryService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.httpClient.object,
testContext.outputChannel);
service.AzureMachineLearningClient = testContext.client.object;
service.ModelClient = testContext.modelClient.object;
testContext.modelClient.setup(x => x.listModels(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(testContext.models));
let actual = await service.downloadModel(testContext.accounts[0], testContext.subscriptions[0], testContext.groups[0], testContext.workspaces[0], testContext.models[0]);
should.notEqual(actual, undefined);
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
});

View File

@@ -0,0 +1,453 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as utils from '../../common/utils';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { Config } from '../../configurations/config';
import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { QueryRunner } from '../../common/queryRunner';
import { ImportedModel } from '../../modelManagement/interfaces';
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
import * as path from 'path';
import * as os from 'os';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as fs from 'fs';
import { ModelConfigRecent } from '../../modelManagement/modelConfigRecent';
import { DatabaseTable } from '../../prediction/interfaces';
import * as queries from '../../modelManagement/queries';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
config: TypeMoq.IMock<Config>;
queryRunner: TypeMoq.IMock<QueryRunner>;
modelClient: TypeMoq.IMock<ModelPythonClient>;
recentModels: TypeMoq.IMock<ModelConfigRecent>;
importTable: DatabaseTable;
}
function createContext(): TestContext {
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
modelClient: TypeMoq.Mock.ofType(ModelPythonClient),
recentModels: TypeMoq.Mock.ofType(ModelConfigRecent),
importTable: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
};
}
describe('DeployedModelService', () => {
it('getDeployedModels should fail with no connection', async function (): Promise<void> {
const testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
let importTable: DatabaseTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
await should(service.getDeployedModels(importTable)).rejected();
});
it('getDeployedModels should returns models successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const expected: ImportedModel[] = [
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
deploymentTime: '2018-01-01',
version: '1.1',
framework: 'onnx',
frameworkVersion: '1',
deployedBy: '1',
runId: 'run1',
table: testContext.importTable
}
];
const result = {
rowCount: 1,
columnInfo: [],
rows: [
[
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'name1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'desc1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1.1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'onnx',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'run1',
isNull: false,
invariantCultureDisplayValue: ''
}
]
]
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
const actual = await service.getDeployedModels(testContext.importTable);
should.deepEqual(actual, expected);
});
it('loadModelParameters should load parameters using python client successfully', async function (): Promise<void> {
const testContext = createContext();
const expected = {
inputs: [
{
'name': 'p1',
'type': 'int'
},
{
'name': 'p2',
'type': 'varchar'
}
],
outputs: [
{
'name': 'o1',
'type': 'int'
},
]
};
testContext.modelClient.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(expected));
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
const actual = await service.loadModelParameters('');
should.deepEqual(actual, expected);
});
it('downloadModel should download model successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, 'test');
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
deploymentTime: '2018-01-01',
version: '1.1',
framework: 'onnx',
frameworkVersion: '1',
deployedBy: '1',
runId: 'run1',
table: testContext.importTable
};
const result = {
rowCount: 1,
columnInfo: [],
rows: [
[
{
displayValue: await utils.readFileInHex(tempFilePath),
isNull: false,
invariantCultureDisplayValue: ''
}
]
]
};
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
const actual = await service.downloadModel(model);
should.notEqual(actual, undefined);
});
it('deployLocalModel should returns models successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
deploymentTime: '2018-01-01',
version: '1.1',
framework: 'onnx',
frameworkVersion: '1',
deployedBy: '1',
runId: 'run1',
table: testContext.importTable
};
const row = [
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'name1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'desc1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1.1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'onnx',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '2018-01-01',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'run1',
isNull: false,
invariantCultureDisplayValue: ''
}
];
const result = {
rowCount: 1,
columnInfo: [],
rows: [row]
};
let updatedResult = {
rowCount: 1,
columnInfo: [],
rows: [row, row]
};
let deployed = false;
let service = new DeployedModelService(
testContext.apiWrapper.object,
testContext.config.object,
testContext.queryRunner.object,
testContext.modelClient.object,
testContext.recentModels.object);
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.is(x => x.indexOf('INSERT INTO') > 0), TypeMoq.It.isAny())).returns(() => {
deployed = true;
return Promise.resolve(result);
});
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
return deployed ? Promise.resolve(updatedResult) : Promise.resolve(result);
});
testContext.queryRunner.setup(x => x.runWithDatabaseChange(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
testContext.config.setup(x => x.registeredModelDatabaseName).returns(() => 'db');
testContext.config.setup(x => x.registeredModelTableName).returns(() => 'table');
testContext.config.setup(x => x.registeredModelTableSchemaName).returns(() => 'dbo');
let tempFilePath: string = '';
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, 'test');
await should(service.deployLocalModel(tempFilePath, model, testContext.importTable)).resolved();
}
finally {
await utils.deleteFile(tempFilePath);
}
});
it('getConfigureQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo';
const expected = `
IF NOT EXISTS
( SELECT t.name, s.name
FROM sys.tables t join sys.schemas s on t.schema_id=t.schema_id
WHERE t.name = 'ta[b]le'
AND s.name = 'dbo'
)
BEGIN
CREATE TABLE [dbo].[ta[[b]]le](
[model_id] [int] IDENTITY(1,1) NOT NULL,
[model_name] [varchar](256) NOT NULL,
[model_framework] [varchar](256) NULL,
[model_framework_version] [varchar](256) NULL,
[model] [varbinary](max) NOT NULL,
[model_version] [varchar](256) NULL,
[model_creation_time] [datetime2] NULL,
[model_deployment_time] [datetime2] NULL,
[deployed_by] [int] NULL,
[model_description] [varchar](256) NULL,
[run_id] [varchar](256) NULL,
CONSTRAINT [ta[[b]]le_models_pk] PRIMARY KEY CLUSTERED
(
[model_id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY]
ALTER TABLE [dbo].[ta[[b]]le] ADD CONSTRAINT [ta[[b]]le_deployment_time] DEFAULT (getdate()) FOR [model_deployment_time]
END
`;
const actual = queries.getConfigureTableQuery(testContext.importTable);
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
});
it('getDeployedModelsQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
testContext.importTable.databaseName = 'd[]b';
testContext.importTable.tableName = 'ta[b]le';
testContext.importTable.schema = 'dbo';
const expected = `
SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id
FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml'
ORDER BY model_id
`;
const actual = queries.getDeployedModelsQuery(testContext.importTable);
should.deepEqual(expected, actual);
});
it('getInsertModelQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
version: '1.1',
table: testContext.importTable
};
const expected = `INSERT INTO [dbo].[tb]
(model_name, model, model_version, model_description, model_creation_time, model_framework, model_framework_version, run_id)
VALUES (
'name1',
,
'1.1',
'desc1',
'2018-01-01',
'',
'',
'')`;
const actual = queries.getInsertModelQuery(model, testContext.importTable);
should.equal(actual.indexOf(expected) >= 0, true, `actual: ${actual} \n expected: ${expected}`);
});
it('getModelContentQuery should escape db name', async function (): Promise<void> {
const testContext = createContext();
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
version: '1.1',
table: testContext.importTable
};
model.table = {
databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo'
};
const expected = `
SELECT model
FROM [d[[]]b].[dbo].[ta[[b]]le]
WHERE model_id = 1;
`;
const actual = queries.getModelContentQuery(model);
should.deepEqual(actual, expected, `actual: ${actual} \n expected: ${expected}`);
});
});

View File

@@ -0,0 +1,121 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { Config } from '../../configurations/config';
import * as utils from '../utils';
import { ProcessService } from '../../common/processService';
import { PackageManager } from '../../packageManagement/packageManager';
import { ModelPythonClient } from '../../modelManagement/modelPythonClient';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
config: TypeMoq.IMock<Config>;
outputChannel: vscode.OutputChannel;
op: azdata.BackgroundOperation;
processService: TypeMoq.IMock<ProcessService>;
packageManager: TypeMoq.IMock<PackageManager>;
}
function createContext(): TestContext {
const context = utils.createContext();
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
config: TypeMoq.Mock.ofType(Config),
outputChannel: context.outputChannel,
op: context.op,
processService: TypeMoq.Mock.ofType(ProcessService),
packageManager: TypeMoq.Mock.ofType(PackageManager)
};
}
describe('ModelPythonClient', () => {
it('deployModel should deploy the model successfully', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
const modelPath = 'C:\\test';
let service = new ModelPythonClient(
testContext.outputChannel,
testContext.apiWrapper.object,
testContext.processService.object,
testContext.config.object,
testContext.packageManager.object);
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
await service.deployModel(connection, modelPath);
});
it('loadModelParameters should load model parameters successfully', async function (): Promise<void> {
const testContext = createContext();
const modelPath = 'C:\\test';
const expected = {
inputs: [
{
'name': 'p1',
'type': 'int'
},
{
'name': 'p2',
'type': 'varchar'
}
],
outputs: [
{
'name': 'o1',
'type': 'int'
},
]
};
const parametersJson = `
{
"inputs": [
{
"name": "p1",
"type": "int"
},
{
"name": "p2",
"type": "varchar"
}
],
"outputs": [
{
"name": "o1",
"type": "int"
}
]
}
`;
let service = new ModelPythonClient(
testContext.outputChannel,
testContext.apiWrapper.object,
testContext.processService.object,
testContext.config.object,
testContext.packageManager.object);
testContext.packageManager.setup(x => x.installRequiredPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.config.setup(x => x.pythonExecutable).returns(() => 'pythonPath');
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(),
TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(parametersJson));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
const actual = await service.loadModelParameters(modelPath);
should.deepEqual(actual, expected);
});
});

View File

@@ -0,0 +1,73 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { QueryRunner } from '../../common/queryRunner';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { PackageManagementService } from '../../packageManagement/packageManagementService';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
queryRunner: TypeMoq.IMock<QueryRunner>;
}
function createContext(): TestContext {
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
};
}
describe('Package Management Service', () => {
it('openDocuments should open document in browser successfully', async function (): Promise<void> {
const context = createContext();
context.apiWrapper.setup(x => x.openExternal(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
should.equal(await serverConfigManager.openDocuments(), true);
});
it('isMachineLearningServiceEnabled should return true if external script is enabled', async function (): Promise<void> {
const context = createContext();
context.queryRunner.setup(x => x.isMachineLearningServiceEnabled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
let connection = new azdata.connection.ConnectionProfile();
should.equal(await serverConfigManager.isMachineLearningServiceEnabled(connection), true);
});
it('isRInstalled should return true if R is installed', async function (): Promise<void> {
const context = createContext();
context.queryRunner.setup(x => x.isRInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
let connection = new azdata.connection.ConnectionProfile();
should.equal(await serverConfigManager.isRInstalled(connection), true);
});
it('isPythonInstalled should return true if Python is installed', async function (): Promise<void> {
const context = createContext();
context.queryRunner.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
let connection = new azdata.connection.ConnectionProfile();
should.equal(await serverConfigManager.isPythonInstalled(connection), true);
});
it('enableExternalScriptConfig should show error message if did not updated successfully', async function (): Promise<void> {
const context = createContext();
context.queryRunner.setup(x => x.updateExternalScriptConfig(TypeMoq.It.isAny(), true)).returns(() => Promise.resolve());
context.queryRunner.setup(x => x.isMachineLearningServiceEnabled(TypeMoq.It.isAny())).returns(() => Promise.resolve(false));
context.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
context.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
context.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
let serverConfigManager = new PackageManagementService(context.apiWrapper.object, context.queryRunner.object);
let connection = new azdata.connection.ConnectionProfile();
await serverConfigManager.enableExternalScriptConfig(connection);
context.apiWrapper.verify(x => x.showErrorMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
});

View File

@@ -0,0 +1,273 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import { PackageManager } from '../../packageManagement/packageManager';
import { createContext, TestContext } from './utils';
describe('Package Manager', () => {
it('Should initialize SQL package manager successfully', async function (): Promise<void> {
let testContext = createContext();
should.doesNotThrow(() => createPackageManager(testContext));
});
it('Manage Package command Should execute the command for valid connection', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Manage Package command Should execute the command if r installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Manage Package command Should show an error for connection without python installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny()));
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isRInstalled(connection)).returns(() => {return Promise.resolve(false);});
testContext.serverConfigManager.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Manage Package command Should show an error for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny()));
testContext.serverConfigManager.setup(x => x.enableExternalScriptConfig(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('installDependencies Should download sqlmlutils if does not exist', async function (): Promise<void> {
let testContext = createContext();
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"},
{"name":"sqlmlutils","version":"1.1.1"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
testContext.httpClient.verify(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('installDependencies Should not install packages if already installed', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"},
{"name":"sqlmlutils","version":"1.1.1"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
}
return Promise.resolve(installedPackages);
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, false);
});
it('installDependencies Should install packages that are not already installed', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"}
]`;
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
}
return Promise.resolve(installedPackages);
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, true);
});
it('installDependencies Should not install packages if runtime is disabled in setting', async function (): Promise<void> {
let testContext = createContext();
testContext.config.setup(x => x.rEnabled).returns(() => false);
testContext.config.setup(x => x.pythonEnabled).returns(() => false);
let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"}
]`;
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip install') > 0 || command.indexOf('install.packages') > 0) {
packagesInstalled = true;
}
return Promise.resolve(installedPackages);
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, false);
});
it('installDependencies Should install packages that have older version installed', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
let installedPackages = `[
{"name":"sqlmlutils","version":"0.1.1"}
]`;
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
}
return Promise.resolve(installedPackages);
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, true);
});
it('installDependencies Should install packages if list packages fails', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({
label: 'Yes'
}));
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command,) => {
if (command.indexOf('pip list') > 0) {
return Promise.reject();
} else if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
return Promise.resolve('');
} else {
return Promise.resolve('');
}
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, true);
});
it('installDependencies Should fail if download packages fails', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject());
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip list') > 0) {
return Promise.resolve(installedPackages);
} else if (command.indexOf('pip install') > 0) {
return Promise.reject();
} else {
return Promise.resolve('');
}
});
let packageManager = createPackageManager(testContext);
await should(packageManager.installDependencies()).rejected();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Failed);
should.equal(packagesInstalled, false);
});
function createPackageManager(testContext: TestContext): PackageManager {
testContext.config.setup(x => x.requiredSqlPythonPackages).returns( () => [
{ name: 'pymssql', version: '2.1.4' },
{ name: 'sqlmlutils', version: '' }
]);
testContext.config.setup(x => x.requiredSqlRPackages).returns( () => [
{ name: 'RODBCext', repository: 'https://cran.microsoft.com' },
{ name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'}
]);
testContext.httpClient.setup(x => x.download(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
testContext.config.setup(x => x.rEnabled).returns(() => true);
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
let packageManager = new PackageManager(
testContext.outputChannel,
'',
testContext.apiWrapper.object,
testContext.serverConfigManager.object,
testContext.processService.object,
testContext.config.object,
testContext.httpClient.object);
packageManager.init();
packageManager.dependenciesInstalled = true;
return packageManager;
}
});

View File

@@ -0,0 +1,399 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import { SqlPythonPackageManageProvider } from '../../packageManagement/sqlPythonPackageManageProvider';
import { createContext, TestContext } from './utils';
import * as nbExtensionApis from '../../typings/notebookServices';
describe('SQL Python Package Manager', () => {
it('Should create SQL package manager successfully', async function (): Promise<void> {
let testContext = createContext();
should.doesNotThrow(() => createProvider(testContext));
});
it('Should return provider Id and target correctly', async function (): Promise<void> {
let testContext = createContext();
let provider = createProvider(testContext);
should.deepEqual(SqlPythonPackageManageProvider.ProviderId, provider.providerId);
should.deepEqual({ location: 'SQL', packageType: 'Python' }, provider.packageTarget);
});
it('listPackages Should return packages sorted by name', async function (): Promise<void> {
let testContext = createContext();
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'b-name',
'version': '1.1.1'
},
{
'name': 'a-name',
'version': '1.1.2'
}
];
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
let provider = createProvider(testContext);
let actual = await provider.listPackages(connection.databaseName);
let expected = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
should.deepEqual(actual, expected);
});
it('listPackages Should return packages sorted by name and version', async function (): Promise<void> {
let testContext = createContext();
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'b-name',
'version': '1.1.1'
},
{
'name': 'b-name',
'version': '1.1.2'
}
];
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
let provider = createProvider(testContext);
let actual = await provider.listPackages(connection.databaseName);
let expected = [
{
'name': 'b-name',
'version': '1.1.1'
},
{
'name': 'b-name',
'version': '1.1.2'
}
];
should.deepEqual(actual, expected);
});
it('listPackages Should return empty packages if undefined packages returned', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
let packages: nbExtensionApis.IPackageDetails[];
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
let provider = createProvider(testContext);
let actual = await provider.listPackages(connection.databaseName);
let expected: nbExtensionApis.IPackageDetails[] = [];
should.deepEqual(actual, expected);
});
it('listPackages Should return empty packages if empty packages returned', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve([]));
let provider = createProvider(testContext);
let actual = await provider.listPackages(connection.databaseName);
let expected: nbExtensionApis.IPackageDetails[] = [];
should.deepEqual(actual, expected);
});
it('installPackages Should install given packages successfully', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('install') > 0) &&
scripts.find(x => x.indexOf('port=1433') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('package="a-name"') > 0) &&
scripts.find(x => x.indexOf('version="1.1.2"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false, connection.databaseName);
should.deepEqual(packagesUpdated, true);
});
it('uninstallPackages Should uninstall given packages successfully', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('uninstall') > 0) &&
scripts.find(x => x.indexOf('port=1433') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('package_name="a-name"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.uninstallPackages(packages, connection.databaseName);
should.deepEqual(packagesUpdated, true);
});
it('installPackages Should include port name in the script', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName,3433';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('install') > 0) &&
scripts.find(x => x.indexOf('port=3433') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('package="a-name"') > 0) &&
scripts.find(x => x.indexOf('version="1.1.2"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false, connection.databaseName);
should.deepEqual(packagesUpdated, true);
});
it('installPackages Should not install any packages give empty list', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
];
let connection = new azdata.connection.ConnectionProfile();
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
packagesUpdated = true;
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false, connection.databaseName);
should.deepEqual(packagesUpdated, false);
});
it('uninstallPackages Should not uninstall any packages give empty list', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
];
let connection = new azdata.connection.ConnectionProfile();
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
packagesUpdated = true;
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.uninstallPackages(packages, connection.databaseName);
should.deepEqual(packagesUpdated, false);
});
it('canUseProvider Should return false for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('canUseProvider Should return false if connection does not have python installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(false));
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('canUseProvider Should return true if connection has python installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, true);
});
it('canUseProvider Should return false if python is disabled in setting', async function (): Promise<void> {
let testContext = createContext();
let provider = createProvider(testContext);
testContext.config.setup(x => x.pythonEnabled).returns(() => false);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('getPackageOverview Should return package info using python packages provider', async function (): Promise<void> {
let testContext = createContext();
let packagePreview = {
name: 'package name',
versions: ['0.0.2', '0.0.1'],
summary: 'package summary'
};
testContext.httpClient.setup(x => x.fetch(TypeMoq.It.isAny())).returns(() => {
return Promise.resolve(`{"info":{"summary":"package summary"}, "releases":{"0.0.1":[{"comment_text":""}], "0.0.2":[{"comment_text":""}]}}`);
});
let provider = createProvider(testContext);
let actual = await provider.getPackageOverview('package name');
should.deepEqual(actual, packagePreview);
});
it('getLocations Should return empty array for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.getLocations();
should.deepEqual(actual, []);
});
it('getLocations Should return database names for valid connection', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
const databaseNames = [
'db1',
'db2'
];
const expected = [
{
displayName: 'db1',
name: 'db1'
},
{
displayName: 'db2',
name: 'db2'
}
];
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.listDatabases(connection.connectionId)).returns(() => { return Promise.resolve(databaseNames); });
let provider = createProvider(testContext);
let actual = await provider.getLocations();
should.deepEqual(actual, expected);
});
function createProvider(testContext: TestContext): SqlPythonPackageManageProvider {
testContext.config.setup(x => x.pythonExecutable).returns(() => 'python');
testContext.config.setup(x => x.pythonEnabled).returns(() => true);
return new SqlPythonPackageManageProvider(
testContext.outputChannel,
testContext.apiWrapper.object,
testContext.serverConfigManager.object,
testContext.processService.object,
testContext.config.object,
testContext.httpClient.object);
}
});

View File

@@ -0,0 +1,325 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import { SqlRPackageManageProvider } from '../../packageManagement/sqlRPackageManageProvider';
import { createContext, TestContext } from './utils';
import * as nbExtensionApis from '../../typings/notebookServices';
describe('SQL R Package Manager', () => {
it('Should create SQL package manager successfully', async function (): Promise<void> {
let testContext = createContext();
should.doesNotThrow(() => createProvider(testContext));
});
it('Should return provider Id and target correctly', async function (): Promise<void> {
let testContext = createContext();
let provider = createProvider(testContext);
should.deepEqual(SqlRPackageManageProvider.ProviderId, provider.providerId);
should.deepEqual({ location: 'SQL', packageType: 'R' }, provider.packageTarget);
});
it('listPackages Should return packages sorted by name', async function (): Promise<void> {
let testContext = createContext();
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'b-name',
'version': '1.1.1'
},
{
'name': 'a-name',
'version': '1.1.2'
}
];
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
let provider = createProvider(testContext);
let actual = await provider.listPackages(connection.databaseName);
let expected = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
should.deepEqual(actual, expected);
});
it('listPackages Should return empty packages if undefined packages returned', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
let packages: nbExtensionApis.IPackageDetails[];
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
let provider = createProvider(testContext);
let actual = await provider.listPackages(connection.databaseName);
let expected: nbExtensionApis.IPackageDetails[] = [];
should.deepEqual(actual, expected);
});
it('listPackages Should return empty packages if empty packages returned', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve([]));
let provider = createProvider(testContext);
let actual = await provider.listPackages(connection.databaseName);
let expected: nbExtensionApis.IPackageDetails[] = [];
should.deepEqual(actual, expected);
});
it('installPackages Should install given packages successfully', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('install') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('"a-name"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false, connection.databaseName);
should.deepEqual(packagesUpdated, true);
});
it('uninstallPackages Should uninstall given packages successfully', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('remove') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('"a-name"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.uninstallPackages(packages, connection.databaseName);
should.deepEqual(packagesUpdated, true);
});
it('installPackages Should not install any packages give empty list', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
];
let connection = new azdata.connection.ConnectionProfile();
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
packagesUpdated = true;
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false, connection.databaseName);
should.deepEqual(packagesUpdated, false);
});
it('uninstallPackages Should not uninstall any packages give empty list', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
];
let connection = new azdata.connection.ConnectionProfile();
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
packagesUpdated = true;
return Promise.resolve('');
});
let provider = createProvider(testContext);
await provider.uninstallPackages(packages, connection.databaseName);
should.deepEqual(packagesUpdated, false);
});
it('canUseProvider Should return false for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('canUseProvider Should return false if connection does not have r installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.isRInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(false));
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('canUseProvider Should return true if connection has r installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.serverConfigManager.setup(x => x.isRInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, true);
});
it('canUseProvider Should return false if r is disabled in setting', async function (): Promise<void> {
let testContext = createContext();
let provider = createProvider(testContext);
testContext.config.setup(x => x.rEnabled).returns(() => false);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('getPackageOverview Should return package info successfully', async function (): Promise<void> {
let testContext = createContext();
let packagePreview = {
'name': 'a-name',
'versions': ['Latest'],
'summary': ''
};
testContext.httpClient.setup(x => x.fetch(TypeMoq.It.isAny())).returns(() => {
return Promise.resolve(``);
});
let provider = createProvider(testContext);
let actual = await provider.getPackageOverview('a-name');
should.deepEqual(actual, packagePreview);
});
it('getLocations Should return empty array for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.getLocations();
should.deepEqual(actual, []);
});
it('getLocations Should return database names for valid connection', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
const databaseNames = [
'db1',
'db2'
];
const expected = [
{
displayName: 'db1',
name: 'db1'
},
{
displayName: 'db2',
name: 'db2'
}
];
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.listDatabases(connection.connectionId)).returns(() => { return Promise.resolve(databaseNames); });
let provider = createProvider(testContext);
let actual = await provider.getLocations();
should.deepEqual(actual, expected);
});
function createProvider(testContext: TestContext): SqlRPackageManageProvider {
testContext.config.setup(x => x.rExecutable).returns(() => 'r');
testContext.config.setup(x => x.rEnabled).returns(() => true);
testContext.config.setup(x => x.rPackagesRepository).returns(() => 'http://cran.r-project.org');
return new SqlRPackageManageProvider(
testContext.outputChannel,
testContext.apiWrapper.object,
testContext.serverConfigManager.object,
testContext.processService.object,
testContext.config.object,
testContext.httpClient.object);
}
});

View File

@@ -0,0 +1,45 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../common/apiWrapper';
import { QueryRunner } from '../../common/queryRunner';
import { ProcessService } from '../../common/processService';
import { Config } from '../../configurations/config';
import { HttpClient } from '../../common/httpClient';
import * as utils from '../utils';
import { PackageManagementService } from '../../packageManagement/packageManagementService';
export interface TestContext {
outputChannel: vscode.OutputChannel;
processService: TypeMoq.IMock<ProcessService>;
apiWrapper: TypeMoq.IMock<ApiWrapper>;
queryRunner: TypeMoq.IMock<QueryRunner>;
config: TypeMoq.IMock<Config>;
op: azdata.BackgroundOperation;
getOpStatus: () => azdata.TaskStatus;
httpClient: TypeMoq.IMock<HttpClient>;
serverConfigManager: TypeMoq.IMock<PackageManagementService>;
}
export function createContext(): TestContext {
const context = utils.createContext();
return {
outputChannel: context.outputChannel,
processService: TypeMoq.Mock.ofType(ProcessService),
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
config: TypeMoq.Mock.ofType(Config),
httpClient: TypeMoq.Mock.ofType(HttpClient),
op: context.op,
getOpStatus: context.getOpStatus,
serverConfigManager: TypeMoq.Mock.ofType(PackageManagementService)
};
}

View File

@@ -0,0 +1,301 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ApiWrapper } from '../../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { PredictService } from '../../prediction/predictService';
import { QueryRunner } from '../../common/queryRunner';
import { ImportedModel } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import * as path from 'path';
import * as os from 'os';
import * as UUID from 'vscode-languageclient/lib/utils/uuid';
import * as fs from 'fs';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
importTable: DatabaseTable;
queryRunner: TypeMoq.IMock<QueryRunner>;
}
function createContext(): TestContext {
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
importTable: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
},
queryRunner: TypeMoq.Mock.ofType(QueryRunner)
};
}
describe('PredictService', () => {
it('getDatabaseList should return databases successfully', async function (): Promise<void> {
const testContext = createContext();
const expected: string[] = [
'db1',
'db2'
];
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.listDatabases(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(expected); });
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object);
const actual = await service.getDatabaseList();
should.deepEqual(actual, expected);
});
it('getTableList should return tables successfully', async function (): Promise<void> {
const testContext = createContext();
const expected: DatabaseTable[] = [
{
databaseName: 'db1',
schema: 'dbo',
tableName: 'tb1'
},
{
databaseName: 'db1',
tableName: 'tb2',
schema: 'dbo'
}
];
const result = {
rowCount: 1,
columnInfo: [],
rows: [[
{
displayValue: 'tb1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'dbo',
isNull: false,
invariantCultureDisplayValue: ''
}
], [
{
displayValue: 'tb2',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'dbo',
isNull: false,
invariantCultureDisplayValue: ''
}
]]
};
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object);
const actual = await service.getTableList('db1');
should.deepEqual(actual, expected);
});
it('getTableColumnsList should return table columns successfully', async function (): Promise<void> {
const testContext = createContext();
const expected: TableColumn[] = [
{
columnName: 'c1',
dataType: 'int'
},
{
columnName: 'c2',
dataType: 'varchar'
}
];
const table: DatabaseTable =
{
databaseName: 'db1',
schema: 'dbo',
tableName: 'tb1'
};
const result = {
rowCount: 1,
columnInfo: [],
rows: [[
{
displayValue: 'c1',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'int',
isNull: false,
invariantCultureDisplayValue: ''
}
], [
{
displayValue: 'c2',
isNull: false,
invariantCultureDisplayValue: ''
},
{
displayValue: 'varchar',
isNull: false,
invariantCultureDisplayValue: ''
}
]]
};
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.queryRunner.setup(x => x.safeRunQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(result));
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object);
const actual = await service.getTableColumnsList(table);
should.deepEqual(actual, expected);
});
it('generatePredictScript should generate the script successfully using model', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const predictParams: PredictParameters = {
inputColumns: [
{
paramName: 'p1',
dataType: 'int',
columnName: ''
},
{
paramName: 'p2',
dataType: 'varchar',
columnName: ''
}
],
outputColumns: [
{
paramName: 'o1',
dataType: 'int',
columnName: ''
},
],
databaseName: '',
tableName: '',
schema: ''
};
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
version: '1.1',
table: testContext.importTable
};
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object);
const document: vscode.TextDocument = {
uri: vscode.Uri.parse('file:///usr/home'),
fileName: '',
isUntitled: true,
languageId: 'sql',
version: 1,
isDirty: true,
isClosed: false,
save: undefined!,
eol: undefined!,
lineCount: 1,
lineAt: undefined!,
offsetAt: undefined!,
positionAt: undefined!,
getText: undefined!,
getWordRangeAtPosition: undefined!,
validateRange: undefined!,
validatePosition: undefined!
};
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
const actual = await service.generatePredictScript(predictParams, model, undefined);
should.notEqual(actual, undefined);
should.equal(actual.indexOf('FROM PREDICT(MODEL = @model') > 0, true);
});
it('generatePredictScript should generate the script successfully using file', async function (): Promise<void> {
const testContext = createContext();
const connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
const predictParams: PredictParameters = {
inputColumns: [
{
paramName: 'p1',
dataType: 'int',
columnName: ''
},
{
paramName: 'p2',
dataType: 'varchar',
columnName: ''
}
],
outputColumns: [
{
paramName: 'o1',
dataType: 'int',
columnName: ''
},
],
databaseName: '',
tableName: '',
schema: ''
};
const tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${UUID.generateUuid()}`);
await fs.promises.writeFile(tempFilePath, 'test');
let service = new PredictService(
testContext.apiWrapper.object,
testContext.queryRunner.object);
const document: vscode.TextDocument = {
uri: vscode.Uri.parse('file:///usr/home'),
fileName: '',
isUntitled: true,
languageId: 'sql',
version: 1,
isDirty: true,
isClosed: false,
save: undefined!,
eol: undefined!,
lineCount: 1,
lineAt: undefined!,
offsetAt: undefined!,
positionAt: undefined!,
getText: undefined!,
getWordRangeAtPosition: undefined!,
validateRange: undefined!,
validatePosition: undefined!
};
testContext.apiWrapper.setup(x => x.openTextDocument(TypeMoq.It.isAny())).returns(() => Promise.resolve(document));
testContext.apiWrapper.setup(x => x.connect(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.apiWrapper.setup(x => x.runQuery(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { });
const actual = await service.generatePredictScript(predictParams, undefined, tempFilePath);
should.notEqual(actual, undefined);
should.equal(actual.indexOf('FROM PREDICT(MODEL = 0X') > 0, true);
});
});

View File

@@ -0,0 +1,303 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { QueryRunner } from '../common/queryRunner';
import { IPackageDetails } from '../typings/notebookServices';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
queryProvider: azdata.QueryProvider;
}
function createContext(): TestContext {
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryProvider: {
providerId: '',
cancelQuery: () => {return Promise.reject();},
runQuery: () => {return Promise.reject();},
runQueryStatement: () => {return Promise.reject();},
runQueryString: () => {return Promise.reject();},
runQueryAndReturn: () => { return Promise.reject(); },
parseSyntax: () => {return Promise.reject();},
getQueryRows: () => {return Promise.reject();},
disposeQuery: () => {return Promise.reject();},
saveResults: () => {return Promise.reject();},
setQueryExecutionOptions: () => {return Promise.reject();},
registerOnQueryComplete: () => {return Promise.reject();},
registerOnBatchStart: () => {return Promise.reject();},
registerOnBatchComplete: () => {return Promise.reject();},
registerOnResultSetAvailable: () => {return Promise.reject();},
registerOnResultSetUpdated: () => {return Promise.reject();},
registerOnMessage: () => {return Promise.reject();},
commitEdit: () => {return Promise.reject();},
createRow: () => {return Promise.reject();},
deleteRow: () => {return Promise.reject();},
disposeEdit: () => {return Promise.reject();},
initializeEdit: () => {return Promise.reject();},
revertCell: () => {return Promise.reject();},
revertRow: () => {return Promise.reject();},
updateCell: () => {return Promise.reject();},
getEditRows: () => {return Promise.reject();},
registerOnEditSessionReady: () => {return Promise.reject();},
}
};
}
describe('Query Runner', () => {
it('getPythonPackages Should return empty list if not provider found', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
let queryProvider: azdata.QueryProvider;
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => queryProvider);
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
should.deepEqual(actual, []);
});
it('getPythonPackages Should return empty list if not provider throws', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.reject(); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
should.deepEqual(actual, []);
});
it('getPythonPackages Should return list if provider runs the query successfully', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: 'p1',
isNull: false,
invariantCultureDisplayValue: ''
}, {
displayValue: '1.1.1',
isNull: false,
invariantCultureDisplayValue: ''
}],
[{
displayValue: 'p2',
isNull: false,
invariantCultureDisplayValue: ''
}, {
displayValue: '1.1.2',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = [
{
'name': 'p1',
'version': '1.1.1'
},
{
'name': 'p2',
'version': '1.1.2'
}
];
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
should.deepEqual(actual, expected);
});
it('getPythonPackages Should return empty list if provider return no rows', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
];
let expected: IPackageDetails[] = [];
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.getPythonPackages(connection, connection.databaseName);
should.deepEqual(actual, expected);
});
it('updateExternalScriptConfig Should update config successfully', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
];
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
await should(queryRunner.updateExternalScriptConfig(connection, true)).resolved();
});
it('isPythonInstalled Should return true is provider returns valid result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = true;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isPythonInstalled(connection);
should.deepEqual(actual, expected);
});
it('isPythonInstalled Should return true is provider returns 0 as result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '0',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isPythonInstalled(connection);
should.deepEqual(actual, expected);
});
it('isPythonInstalled Should return false is provider returns no result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isPythonInstalled(connection);
should.deepEqual(actual, expected);
});
it('isMachineLearningServiceEnabled Should return true is provider returns valid result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = true;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
should.deepEqual(actual, expected);
});
it('isMachineLearningServiceEnabled Should return true is provider returns 0 as result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '0',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
should.deepEqual(actual, expected);
});
it('isMachineLearningServiceEnabled Should return false is provider returns no result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
should.deepEqual(actual, expected);
});
});

View File

@@ -0,0 +1,38 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
export interface TestContext {
outputChannel: vscode.OutputChannel;
op: azdata.BackgroundOperation;
getOpStatus: () => azdata.TaskStatus;
}
export function createContext(): TestContext {
let opStatus: azdata.TaskStatus;
return {
outputChannel: {
name: '',
append: () => { },
appendLine: () => { },
clear: () => { },
show: () => { },
hide: () => { },
dispose: () => { }
},
op: {
updateStatus: (status: azdata.TaskStatus) => {
opStatus = status;
},
id: '',
onCanceled: new vscode.EventEmitter<void>().event,
},
getOpStatus: () => { return opStatus; }
};
}

View File

@@ -0,0 +1,39 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../common/apiWrapper';
import { createViewContext } from './utils';
import { DashboardWidget } from '../../views/widgets/dashboardWidget';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
view: azdata.ModelView;
onClick: vscode.EventEmitter<any>;
}
function createContext(): TestContext {
let viewTestContext = createViewContext();
return {
apiWrapper: viewTestContext.apiWrapper,
view: viewTestContext.view,
onClick: viewTestContext.onClick
};
}
describe('Dashboard widget', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
const dashboard = new DashboardWidget(testContext.apiWrapper.object, '');
dashboard.register();
testContext.onClick.fire(undefined);
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce());
});
});

View File

@@ -0,0 +1,120 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import { createContext, ParentDialog } from './utils';
import { AddEditLanguageTab } from '../../../views/externalLanguages/addEditLanguageTab';
import { LanguageUpdateModel } from '../../../views/externalLanguages/languageViewBase';
describe('Add Edit External Languages Tab', () => {
it('Should create AddEditLanguageTab for new language successfully ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let languageUpdateModel: LanguageUpdateModel = {
content: parent.createNewContent(),
language: parent.createNewLanguage(),
newLang: true
};
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
should.notEqual(tab.languageView, undefined, 'Failed to create language view for add');
});
it('Should create AddEditLanguageTab for edit successfully ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let languageUpdateModel: LanguageUpdateModel = {
content: {
extensionFileName: 'filename',
isLocalFile: true,
pathToExtension: 'path',
},
language: {
name: 'name',
contents: []
},
newLang: false
};
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
should.notEqual(tab.languageView, undefined, 'Failed to create language view for edit');
should.equal(tab.saveButton, undefined);
});
it('Should reset AddEditLanguageTab successfully ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let languageUpdateModel: LanguageUpdateModel = {
content: {
extensionFileName: 'filename',
isLocalFile: true,
pathToExtension: 'path',
},
language: {
name: 'name',
contents: []
},
newLang: false
};
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
if (tab.languageName) {
tab.languageName.value = 'some value';
}
await tab.reset();
should.equal(tab.languageName?.value, 'name');
});
it('Should load content successfully ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let languageUpdateModel: LanguageUpdateModel = {
content: {
extensionFileName: 'filename',
isLocalFile: true,
pathToExtension: 'path',
environmentVariables: 'env vars',
parameters: 'params'
},
language: {
name: 'name',
contents: []
},
newLang: false
};
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
let content = tab.languageView?.updatedContent;
should.notEqual(content, undefined);
if (content) {
should.equal(content.extensionFileName, languageUpdateModel.content.extensionFileName);
should.equal(content.pathToExtension, languageUpdateModel.content.pathToExtension);
should.equal(content.environmentVariables, languageUpdateModel.content.environmentVariables);
should.equal(content.parameters, languageUpdateModel.content.parameters);
}
});
it('Should raise save event if save button clicked ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let languageUpdateModel: LanguageUpdateModel = {
content: parent.createNewContent(),
language: parent.createNewLanguage(),
newLang: true
};
let tab = new AddEditLanguageTab(testContext.apiWrapper.object, parent, languageUpdateModel);
should.notEqual(tab.saveButton, undefined);
let updateCalled = false;
let promise = new Promise(resolve => {
parent.onUpdate(() => {
updateCalled = true;
resolve();
});
});
testContext.onClick.fire(undefined);
parent.onUpdatedLanguage(languageUpdateModel);
await promise;
should.equal(updateCalled, true);
should.notEqual(tab.updatedData, undefined);
});
});

View File

@@ -0,0 +1,104 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import { createContext } from './utils';
import { LanguageController } from '../../../views/externalLanguages/languageController';
import * as mssql from '../../../../../mssql';
describe('External Languages Controller', () => {
it('Should open dialog for manage languages successfully ', async function (): Promise<void> {
let testContext = createContext();
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
let dialog = await controller.manageLanguages();
testContext.apiWrapper.verify(x => x.openDialog(TypeMoq.It.isAny()), TypeMoq.Times.once());
should.notEqual(dialog, undefined);
});
it('Should list languages successfully ', async function (): Promise<void> {
let testContext = createContext();
let languages: mssql.ExternalLanguage[] = [{
name: '',
contents: [{
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
}]
}];
testContext.dialogModel.setup( x=> x.getLanguageList()).returns(() => Promise.resolve(languages));
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
let dialog = await controller.manageLanguages();
let actual = await dialog.listLanguages();
should.deepEqual(actual, languages);
});
it('Should update languages successfully ', async function (): Promise<void> {
let testContext = createContext();
let language: mssql.ExternalLanguage = {
name: '',
contents: [{
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
}]
};
testContext.dialogModel.setup( x=> x.updateLanguage(language)).returns(() => Promise.resolve());
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
let dialog = await controller.manageLanguages();
await dialog.updateLanguage({
language: language,
content: language.contents[0],
newLang: false
});
testContext.dialogModel.verify(x => x.updateLanguage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Should delete language successfully ', async function (): Promise<void> {
let testContext = createContext();
let language: mssql.ExternalLanguage = {
name: '',
contents: [{
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
}]
};
testContext.dialogModel.setup( x=> x.deleteLanguage(language.name)).returns(() => Promise.resolve());
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
let dialog = await controller.manageLanguages();
await dialog.deleteLanguage({
language: language,
content: language.contents[0],
newLang: false
});
testContext.dialogModel.verify(x => x.deleteLanguage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Should open edit dialog for edit language', async function (): Promise<void> {
let testContext = createContext();
let language: mssql.ExternalLanguage = {
name: '',
contents: [{
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
}]
};
let controller = new LanguageController(testContext.apiWrapper.object, '', testContext.dialogModel.object);
let dialog = await controller.manageLanguages();
dialog.onEditLanguage({
language: language,
content: language.contents[0],
newLang: false
});
testContext.apiWrapper.verify(x => x.openDialog(TypeMoq.It.isAny()), TypeMoq.Times.exactly(2));
should.notEqual(dialog, undefined);
});
});

View File

@@ -0,0 +1,50 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import { createContext, ParentDialog } from './utils';
import { LanguageEditDialog } from '../../../views/externalLanguages/languageEditDialog';
import { LanguageUpdateModel } from '../../../views/externalLanguages/languageViewBase';
describe('Edit External Languages Dialog', () => {
it('Should open dialog successfully ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let languageUpdateModel: LanguageUpdateModel = {
content: parent.createNewContent(),
language: parent.createNewLanguage(),
newLang: true
};
let dialog = new LanguageEditDialog(testContext.apiWrapper.object, parent, languageUpdateModel);
dialog.showDialog();
should.notEqual(dialog.addNewLanguageTab, undefined);
});
it('Should raise save event if save button clicked ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let languageUpdateModel: LanguageUpdateModel = {
content: parent.createNewContent(),
language: parent.createNewLanguage(),
newLang: true
};
let dialog = new LanguageEditDialog(testContext.apiWrapper.object, parent, languageUpdateModel);
dialog.showDialog();
let updateCalled = false;
let promise = new Promise(resolve => {
parent.onUpdate(() => {
updateCalled = true;
parent.onUpdatedLanguage(languageUpdateModel);
resolve();
});
});
dialog.onSave();
await promise;
should.equal(updateCalled, true);
});
});

View File

@@ -0,0 +1,19 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { LanguagesDialog } from '../../../views/externalLanguages/languagesDialog';
describe('External Languages Dialog', () => {
it('Should open dialog successfully ', async function (): Promise<void> {
let testContext = createContext();
let dialog = new LanguagesDialog(testContext.apiWrapper.object, '');
dialog.showDialog();
should.notEqual(dialog.addNewLanguageTab, undefined);
should.notEqual(dialog.currentLanguagesTab, undefined);
});
});

View File

@@ -0,0 +1,61 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import * as mssql from '../../../../../mssql';
import { LanguageService } from '../../../externalLanguage/languageService';
describe('External Languages Dialog Model', () => {
it('Should list languages successfully ', async function (): Promise<void> {
let testContext = createContext();
let languages: mssql.ExternalLanguage[] = [{
name: '',
contents: [{
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
}]
}];
testContext.languageExtensionService.listLanguages = () => {return Promise.resolve(languages);};
let model = new LanguageService(testContext.apiWrapper.object, testContext.languageExtensionService);
await model.load();
let actual = await model.getLanguageList();
should.deepEqual(actual, languages);
});
it('Should update language successfully ', async function (): Promise<void> {
let testContext = createContext();
let language: mssql.ExternalLanguage = {
name: '',
contents: [{
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
}]
};
let model = new LanguageService(testContext.apiWrapper.object, testContext.languageExtensionService);
await model.load();
await should(model.updateLanguage(language)).resolved();
});
it('Should delete language successfully ', async function (): Promise<void> {
let testContext = createContext();
let language: mssql.ExternalLanguage = {
name: '',
contents: [{
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
}]
};
let model = new LanguageService(testContext.apiWrapper.object, testContext.languageExtensionService);
await model.load();
await should(model.deleteLanguage(language.name)).resolved();
});
});

View File

@@ -0,0 +1,53 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../../common/apiWrapper';
import { LanguageViewBase } from '../../../views/externalLanguages/languageViewBase';
import * as mssql from '../../../../../mssql';
import { LanguageService } from '../../../externalLanguage/languageService';
import { createViewContext } from '../utils';
export interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
view: azdata.ModelView;
languageExtensionService: mssql.ILanguageExtensionService;
onClick: vscode.EventEmitter<any>;
dialogModel: TypeMoq.IMock<LanguageService>;
}
export class ParentDialog extends LanguageViewBase {
public reset(): Promise<void> {
return Promise.resolve();
}
constructor(
apiWrapper: ApiWrapper) {
super(apiWrapper, '');
}
}
export function createContext(): TestContext {
let viewTestContext = createViewContext();
let connection = new azdata.connection.ConnectionProfile();
viewTestContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
viewTestContext.apiWrapper.setup(x => x.getUriForConnection(TypeMoq.It.isAny())).returns(() => { return Promise.resolve('connectionUrl'); });
let languageExtensionService: mssql.ILanguageExtensionService = {
listLanguages: () => { return Promise.resolve([]); },
deleteLanguage: () => { return Promise.resolve(); },
updateLanguage: () => { return Promise.resolve(); }
};
return {
apiWrapper: viewTestContext.apiWrapper,
view: viewTestContext.view,
languageExtensionService: languageExtensionService,
onClick: viewTestContext.onClick,
dialogModel: TypeMoq.Mock.ofType(LanguageService)
};
}

View File

@@ -0,0 +1,202 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import * as TypeMoq from 'typemoq';
import 'mocha';
import { createContext } from './utils';
import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { ModelManagementController } from '../../../views/models/modelManagementController';
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
import { DeleteModelEventName, UpdateModelEventName } from '../../../views/models/modelViewBase';
import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog';
const accounts: azdata.Account[] = [
{
key: {
accountId: '1',
providerId: ''
},
displayInfo: {
displayName: 'account',
userId: '',
accountType: '',
contextualDisplayName: ''
},
isStale: false,
properties: []
}
];
const subscriptions: azureResource.AzureResourceSubscription[] = [
{
name: 'subscription',
id: '2'
}
];
const groups: azureResource.AzureResourceResourceGroup[] = [
{
name: 'group',
id: '3'
}
];
const workspaces: Workspace[] = [
{
name: 'workspace',
id: '4'
}
];
const models: WorkspaceModel[] = [
{
id: '5',
name: 'model'
}
];
const localModels: ImportedModel[] = [
{
id: 1,
modelName: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
const dbNames: string[] = [
'db1',
'db2'
];
const tableNames: DatabaseTable[] = [
{
databaseName: 'db1',
schema: 'dbo',
tableName: 'tb1'
},
{
databaseName: 'db1',
tableName: 'tb2',
schema: 'dbo'
}
];
const columnNames: TableColumn[] = [
{
columnName: 'c1',
dataType: 'int'
},
{
columnName: 'c2',
dataType: 'varchar'
}
];
const modelParameters: ModelParameters = {
inputs: [
{
'name': 'p1',
'type': 'int'
},
{
'name': 'p2',
'type': 'varchar'
}
],
outputs: [
{
'name': 'o1',
'type': 'int'
}
]
};
describe('Model Controller', () => {
it('Should open deploy model wizard successfully ', async function (): Promise<void> {
let testContext = createContext();
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
databaseName: 'db',
tableName: 'table',
schema: 'dbo'
}));
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames));
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames));
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
const view = await controller.registerModel(undefined);
should.notEqual(view, undefined);
});
it('Should open predict wizard successfully ', async function (): Promise<void> {
let testContext = createContext();
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
testContext.deployModelService.setup(x => x.getRecentImportTable()).returns(() => Promise.resolve({
databaseName: 'db',
tableName: 'table',
schema: 'dbo'
}));
testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels));
testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve([
'db', 'db1'
]));
testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve([
{ tableName: 'tb', databaseName: 'db', schema: 'dbo' }
]));
testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts));
testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions));
testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups));
testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces));
testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models));
testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames));
testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters));
testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file'));
const view = await controller.predictModel();
should.notEqual(view, undefined);
});
it('Should open edit model dialog successfully ', async function (): Promise<void> {
let testContext = createContext();
testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
testContext.deployModelService.setup(x => x.deleteModel(TypeMoq.It.isAny())).returns(() => Promise.resolve());
let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object);
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
version: '1.1',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
};
const view = <EditModelDialog>await controller.editModel(model);
should.notEqual(view?.editModelPage, undefined);
if (view.editModelPage) {
view.editModelPage.sendRequest(UpdateModelEventName, model);
view.editModelPage.sendRequest(DeleteModelEventName, model);
}
testContext.deployModelService.verify(x => x.updateModel(model), TypeMoq.Times.atLeastOnce());
testContext.deployModelService.verify(x => x.deleteModel(model), TypeMoq.Times.atLeastOnce());
should.notEqual(view, undefined);
});
});

View File

@@ -0,0 +1,102 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import { createContext, ParentDialog } from './utils';
import { AzureModelsComponent } from '../../../views/models/azureModelsComponent';
import { ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName } from '../../../views/models/modelViewBase';
import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
describe('Azure Models Component', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent);
view.registerComponent(testContext.view.modelBuilder);
should.notEqual(view.component, undefined);
});
it('Should load data successfully ', async function (): Promise<void> {
let testContext = createContext();
let parent = new ParentDialog(testContext.apiWrapper.object);
let view = new AzureModelsComponent(testContext.apiWrapper.object, parent, false);
view.registerComponent(testContext.view.modelBuilder);
let accounts: azdata.Account[] = [
{
key: {
accountId: '1',
providerId: ''
},
displayInfo: {
displayName: 'account',
userId: '',
accountType: '',
contextualDisplayName: ''
},
isStale: false,
properties: []
}
];
let subscriptions: azureResource.AzureResourceSubscription[] = [
{
name: 'subscription',
id: '2'
}
];
let groups: azureResource.AzureResourceResourceGroup[] = [
{
name: 'group',
id: '3'
}
];
let workspaces: Workspace[] = [
{
name: 'workspace',
id: '4'
}
];
let models: WorkspaceModel[] = [
{
id: '5',
name: 'model'
}
];
parent.on(ListAccountsEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
});
parent.on(ListSubscriptionsEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
});
parent.on(ListGroupsEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
});
parent.on(ListWorkspacesEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
});
parent.on(ListAzureModelsEventName, () => {
parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
});
await view.refresh();
testContext.onClick.fire(true);
should.notEqual(view.data, undefined);
should.equal(view.data?.length, 1);
if (view.data) {
should.deepEqual(view.data[0].account, accounts[0]);
should.deepEqual(view.data[0].subscription, subscriptions[0]);
should.deepEqual(view.data[0].group, groups[0]);
should.deepEqual(view.data[0].workspace, workspaces[0]);
should.deepEqual(view.data[0].model, models[0]);
}
});
});

View File

@@ -0,0 +1,33 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { EditModelDialog } from '../../../views/models/manageModels/editModelDialog';
describe('Edit Model Dialog', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
const model: ImportedModel =
{
id: 1,
modelName: 'name1',
description: 'desc1',
created: '2018-01-01',
version: '1.1',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
};
let view = new EditModelDialog(testContext.apiWrapper.object, '', undefined, model);
view.open();
should.notEqual(view.dialogView, undefined);
});
});

View File

@@ -0,0 +1,203 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import {
ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName,
ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType
}
from '../../../views/models/modelViewBase';
import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { PredictWizard } from '../../../views/models/prediction/predictWizard';
import { DatabaseTable, TableColumn } from '../../../prediction/interfaces';
describe('Predict Wizard', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new PredictWizard(testContext.apiWrapper.object, '');
await view.open();
should.notEqual(view.wizardView, undefined);
should.notEqual(view.modelSourcePage, undefined);
});
it('Should load data successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new PredictWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
await view.open();
let accounts: azdata.Account[] = [
{
key: {
accountId: '1',
providerId: ''
},
displayInfo: {
displayName: 'account',
userId: '',
accountType: '',
contextualDisplayName: ''
},
isStale: false,
properties: []
}
];
let subscriptions: azureResource.AzureResourceSubscription[] = [
{
name: 'subscription',
id: '2'
}
];
let groups: azureResource.AzureResourceResourceGroup[] = [
{
name: 'group',
id: '3'
}
];
let workspaces: Workspace[] = [
{
name: 'workspace',
id: '4'
}
];
let models: WorkspaceModel[] = [
{
id: '5',
name: 'model'
}
];
let localModels: ImportedModel[] = [
{
id: 1,
modelName: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
const dbNames: string[] = [
'db1',
'db2'
];
const tableNames: DatabaseTable[] = [
{
databaseName: 'db1',
schema: 'dbo',
tableName: 'tb1'
},
{
databaseName: 'db1',
tableName: 'tb2',
schema: 'dbo'
}
];
const columnNames: TableColumn[] = [
{
columnName: 'c1',
dataType: 'int'
},
{
columnName: 'c2',
dataType: 'varchar'
}
];
const modelParameters: ModelParameters = {
inputs: [
{
'name': 'p1',
'type': 'int'
},
{
'name': 'p2',
'type': 'varchar'
}
],
outputs: [
{
'name': 'o1',
'type': 'int'
}
]
};
view.on(ListModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
});
view.on(ListAccountsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
});
view.on(ListSubscriptionsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
});
view.on(ListGroupsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
});
view.on(ListWorkspacesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
});
view.on(ListAzureModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
});
view.on(ListDatabaseNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames });
});
view.on(ListTableNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames });
});
view.on(ListColumnNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames });
});
view.on(LoadModelParametersEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters });
});
view.on(DownloadAzureModelEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' });
});
view.on(DownloadRegisteredModelEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' });
});
if (view.modelBrowsePage) {
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
}
await view.refresh();
should.notEqual(view.azureModelsComponent?.data, undefined);
if (view.modelBrowsePage) {
view.modelBrowsePage.modelSourceType = ModelSourceType.RegisteredModels;
}
await view.refresh();
testContext.onClick.fire(undefined);
should.equal(view.modelSourcePage?.data, ModelSourceType.RegisteredModels);
should.notEqual(view.localModelsComponent?.data, undefined);
should.notEqual(view.modelBrowsePage?.registeredModelsComponent?.data, undefined);
if (view.modelBrowsePage?.registeredModelsComponent?.data) {
should.equal(view.modelBrowsePage.registeredModelsComponent.data.length, 1);
}
should.notEqual(await view.getModelFileName(), undefined);
await view.columnsSelectionPage?.onEnter();
should.notEqual(view.columnsSelectionPage?.data, undefined);
should.equal(view.columnsSelectionPage?.data?.inputColumns?.length, modelParameters.inputs.length, modelParameters.inputs[0].name);
should.equal(view.columnsSelectionPage?.data?.outputColumns?.length, modelParameters.outputs.length);
});
});

View File

@@ -0,0 +1,131 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { azureResource } from '../../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ViewBase } from '../../../views/viewBase';
import { WorkspaceModel } from '../../../modelManagement/interfaces';
import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard';
describe('Register Model Wizard', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'table',
schema: 'dbo'
};
await view.open();
should.notEqual(view.wizardView, undefined);
should.notEqual(view.modelSourcePage, undefined);
});
it('Should load data successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new ImportModelWizard(testContext.apiWrapper.object, '');
view.importTable = {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
};
await view.open();
let accounts: azdata.Account[] = [
{
key: {
accountId: '1',
providerId: ''
},
displayInfo: {
displayName: 'account',
userId: '',
accountType: '',
contextualDisplayName: ''
},
isStale: false,
properties: []
}
];
let subscriptions: azureResource.AzureResourceSubscription[] = [
{
name: 'subscription',
id: '2'
}
];
let groups: azureResource.AzureResourceResourceGroup[] = [
{
name: 'group',
id: '3'
}
];
let workspaces: Workspace[] = [
{
name: 'workspace',
id: '4'
}
];
let models: WorkspaceModel[] = [
{
id: '5',
name: 'model'
}
];
let localModels: ImportedModel[] = [
{
id: 1,
modelName: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
view.on(ListModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels });
});
view.on(ListDatabaseNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [
'db', 'db1'
] });
});
view.on(ListTableNamesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [
'tb', 'tb1'
] });
});
view.on(ListAccountsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts });
});
view.on(ListSubscriptionsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions });
});
view.on(ListGroupsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups });
});
view.on(ListWorkspacesEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces });
});
view.on(ListAzureModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models });
});
if (view.modelBrowsePage) {
view.modelBrowsePage.modelSourceType = ModelSourceType.Azure;
}
await view.refresh();
should.notEqual(view.azureModelsComponent?.data ,undefined);
should.notEqual(view.localModelsComponent?.data, undefined);
});
});

View File

@@ -0,0 +1,46 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as should from 'should';
import 'mocha';
import { createContext } from './utils';
import { ManageModelsDialog } from '../../../views/models/manageModels/manageModelsDialog';
import { ListModelsEventName } from '../../../views/models/modelViewBase';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { ViewBase } from '../../../views/viewBase';
describe('Registered Models Dialog', () => {
it('Should create view components successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
view.open();
should.notEqual(view.dialogView, undefined);
should.notEqual(view.currentLanguagesTab, undefined);
});
it('Should load data successfully ', async function (): Promise<void> {
let testContext = createContext();
let view = new ManageModelsDialog(testContext.apiWrapper.object, '');
view.open();
let models: ImportedModel[] = [
{
id: 1,
modelName: 'model',
table: {
databaseName: 'db',
tableName: 'tb',
schema: 'dbo'
}
}
];
view.on(ListModelsEventName, () => {
view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: models });
});
await view.refresh();
});
});

View File

@@ -0,0 +1,50 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../../common/apiWrapper';
import { createViewContext } from '../utils';
import { ModelViewBase } from '../../../views/models/modelViewBase';
import { AzureModelRegistryService } from '../../../modelManagement/azureModelRegistryService';
import { DeployedModelService } from '../../../modelManagement/deployedModelService';
import { PredictService } from '../../../prediction/predictService';
export interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
view: azdata.ModelView;
onClick: vscode.EventEmitter<any>;
azureModelService: TypeMoq.IMock<AzureModelRegistryService>;
deployModelService: TypeMoq.IMock<DeployedModelService>;
predictService: TypeMoq.IMock<PredictService>;
}
export class ParentDialog extends ModelViewBase {
public refresh(): Promise<void> {
return Promise.resolve();
}
public reset(): Promise<void> {
return Promise.resolve();
}
constructor(
apiWrapper: ApiWrapper) {
super(apiWrapper, '');
}
}
export function createContext(): TestContext {
let viewTestContext = createViewContext();
return {
apiWrapper: viewTestContext.apiWrapper,
view: viewTestContext.view,
onClick: viewTestContext.onClick,
azureModelService: TypeMoq.Mock.ofType(AzureModelRegistryService),
deployModelService: TypeMoq.Mock.ofType(DeployedModelService),
predictService: TypeMoq.Mock.ofType(PredictService)
};
}

View File

@@ -0,0 +1,333 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../common/apiWrapper';
export interface ViewTestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
view: azdata.ModelView;
onClick: vscode.EventEmitter<any>;
}
export function createViewContext(): ViewTestContext {
let onClick: vscode.EventEmitter<any> = new vscode.EventEmitter<any>();
let apiWrapper = TypeMoq.Mock.ofType(ApiWrapper);
let componentBase: azdata.Component = {
id: '',
updateProperties: () => Promise.resolve(),
updateProperty: () => Promise.resolve(),
updateCssStyles: undefined!,
onValidityChanged: undefined!,
valid: true,
validate: undefined!,
focus: undefined!
};
let button: azdata.ButtonComponent = Object.assign({}, componentBase, {
onDidClick: onClick.event
});
let link: azdata.HyperlinkComponent = Object.assign({}, componentBase, {
onDidClick: onClick.event,
label: '',
url: ''
});
let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, {
checked: true,
onDidClick: onClick.event
});
let checkbox: azdata.CheckBoxComponent = Object.assign({}, componentBase, {
checked: true,
onChanged: onClick.event
});
let container = {
clearItems: () => { },
addItems: () => { },
addItem: () => { },
removeItem: () => true,
insertItem: () => { },
items: [],
setLayout: () => { }
};
let form: azdata.FormContainer = Object.assign({}, componentBase, container, {
});
let flex: azdata.FlexContainer = Object.assign({}, componentBase, container, {
});
let div: azdata.DivContainer = Object.assign({}, componentBase, container, {
onDidClick: onClick.event
});
let buttonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
component: () => button,
withProperties: () => buttonBuilder,
withValidation: () => buttonBuilder
};
let hyperLinkBuilder: azdata.ComponentBuilder<azdata.HyperlinkComponent> = {
component: () => link,
withProperties: () => hyperLinkBuilder,
withValidation: () => hyperLinkBuilder
};
let radioButtonBuilder: azdata.ComponentBuilder<azdata.ButtonComponent> = {
component: () => radioButton,
withProperties: () => radioButtonBuilder,
withValidation: () => radioButtonBuilder
};
let checkBoxBuilder: azdata.ComponentBuilder<azdata.CheckBoxComponent> = {
component: () => checkbox,
withProperties: () => checkBoxBuilder,
withValidation: () => checkBoxBuilder
};
let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, {
onTextChanged: onClick.event!,
onEnterKeyPressed: undefined!,
value: ''
});
let image: () => azdata.ImageComponent = () => Object.assign({}, componentBase, {
});
let dropdown: () => azdata.DropDownComponent = () => Object.assign({}, componentBase, {
onValueChanged: onClick.event,
value: {
name: '',
displayName: ''
},
values: []
});
let declarativeTable: () => azdata.DeclarativeTableComponent = () => Object.assign({}, componentBase, {
onDataChanged: undefined!,
data: [],
columns: []
});
let loadingComponent: () => azdata.LoadingComponent = () => Object.assign({}, componentBase, {
loading: false,
component: undefined!
});
let card: () => azdata.CardComponent = () => Object.assign({}, componentBase, {
label: '',
onDidActionClick: new vscode.EventEmitter<azdata.ActionDescriptor>().event,
onCardSelectedChanged: onClick.event
});
let declarativeTableBuilder: azdata.ComponentBuilder<azdata.DeclarativeTableComponent> = {
component: () => declarativeTable(),
withProperties: () => declarativeTableBuilder,
withValidation: () => declarativeTableBuilder
};
let loadingBuilder: azdata.LoadingComponentBuilder = {
component: () => loadingComponent(),
withProperties: () => loadingBuilder,
withValidation: () => loadingBuilder,
withItem: () => loadingBuilder
};
let formBuilder: azdata.FormBuilder = Object.assign({}, {
component: () => form,
addFormItem: () => { },
insertFormItem: () => { },
removeFormItem: () => true,
addFormItems: () => { },
withFormItems: () => formBuilder,
withProperties: () => formBuilder,
withValidation: () => formBuilder,
withItems: () => formBuilder,
withLayout: () => formBuilder
});
let flexBuilder: azdata.FlexBuilder = Object.assign({}, {
component: () => flex,
withProperties: () => flexBuilder,
withValidation: () => flexBuilder,
withItems: () => flexBuilder,
withLayout: () => flexBuilder
});
let divBuilder: azdata.DivBuilder = Object.assign({}, {
component: () => div,
withProperties: () => divBuilder,
withValidation: () => divBuilder,
withItems: () => divBuilder,
withLayout: () => divBuilder
});
let inputBoxBuilder: azdata.ComponentBuilder<azdata.InputBoxComponent> = {
component: () => {
let r = inputBox();
return r;
},
withProperties: () => inputBoxBuilder,
withValidation: () => inputBoxBuilder
};
let cardBuilder: azdata.ComponentBuilder<azdata.CardComponent> = {
component: () => {
let r = card();
return r;
},
withProperties: () => cardBuilder,
withValidation: () => cardBuilder
};
let imageBuilder: azdata.ComponentBuilder<azdata.ImageComponent> = {
component: () => {
let r = image();
return r;
},
withProperties: () => imageBuilder,
withValidation: () => imageBuilder
};
let dropdownBuilder: azdata.ComponentBuilder<azdata.DropDownComponent> = {
component: () => {
let r = dropdown();
return r;
},
withProperties: () => dropdownBuilder,
withValidation: () => dropdownBuilder
};
let view: azdata.ModelView = {
onClosed: undefined!,
connection: undefined!,
serverInfo: undefined!,
valid: true,
onValidityChanged: undefined!,
validate: undefined!,
initializeModel: () => { return Promise.resolve(); },
modelBuilder: {
radioCardGroup: undefined!,
navContainer: undefined!,
divContainer: () => divBuilder,
flexContainer: () => flexBuilder,
splitViewContainer: undefined!,
dom: undefined!,
card: () => cardBuilder,
inputBox: () => inputBoxBuilder,
checkBox: () => checkBoxBuilder!,
radioButton: () => radioButtonBuilder,
webView: undefined!,
editor: undefined!,
diffeditor: undefined!,
text: () => inputBoxBuilder,
image: () => imageBuilder,
button: () => buttonBuilder,
dropDown: () => dropdownBuilder,
tree: undefined!,
listBox: undefined!,
table: undefined!,
declarativeTable: () => declarativeTableBuilder,
dashboardWidget: undefined!,
dashboardWebview: undefined!,
formContainer: () => formBuilder,
groupContainer: undefined!,
toolbarContainer: undefined!,
loadingComponent: () => loadingBuilder,
fileBrowserTree: undefined!,
hyperlink: () => hyperLinkBuilder,
tabbedPanel: undefined!,
separator: undefined!,
propertiesContainer: undefined!
}
};
let tab: azdata.window.DialogTab = {
title: '',
content: '',
registerContent: async (handler) => {
try {
await handler(view);
} catch (err) {
throw err;
}
},
onValidityChanged: undefined!,
valid: true,
modelView: undefined!
};
let dialogButton: azdata.window.Button = {
label: '',
enabled: true,
hidden: false,
onClick: onClick.event,
};
let dialogMessage: azdata.window.DialogMessage = {
text: '',
};
let dialog: azdata.window.Dialog = {
title: '',
isWide: false,
content: [],
okButton: dialogButton,
cancelButton: dialogButton,
customButtons: [],
message: dialogMessage,
registerCloseValidator: () => { },
registerOperation: () => { },
onValidityChanged: new vscode.EventEmitter<boolean>().event,
registerContent: () => { },
modelView: undefined!,
valid: true
};
let wizard: azdata.window.Wizard = {
title: '',
pages: [],
currentPage: 0,
doneButton: dialogButton,
cancelButton: dialogButton,
generateScriptButton: dialogButton,
nextButton: dialogButton,
backButton: dialogButton,
customButtons: [],
displayPageTitles: true,
onPageChanged: onClick.event,
addPage: () => { return Promise.resolve(); },
removePage: () => { return Promise.resolve(); },
setCurrentPage: () => { return Promise.resolve(); },
open: () => { return Promise.resolve(); },
close: () => { return Promise.resolve(); },
registerNavigationValidator: () => { },
message: dialogMessage,
registerOperation: () => { }
};
let wizardPage: azdata.window.WizardPage = {
title: '',
content: '',
customButtons: [],
enabled: true,
description: '',
onValidityChanged: onClick.event,
registerContent: async (handler) => {
try {
await handler(view);
} catch (err) {
throw err;
}
},
modelView: undefined!,
valid: true
};
apiWrapper.setup(x => x.createButton(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => dialogButton);
apiWrapper.setup(x => x.createTab(TypeMoq.It.isAny())).returns(() => tab);
apiWrapper.setup(x => x.createWizard(TypeMoq.It.isAny())).returns(() => wizard);
apiWrapper.setup(x => x.createWizardPage(TypeMoq.It.isAny())).returns(() => wizardPage);
apiWrapper.setup(x => x.createModelViewDialog(TypeMoq.It.isAny())).returns(() => dialog);
apiWrapper.setup(x => x.openDialog(TypeMoq.It.isAny())).returns(() => { });
apiWrapper.setup(x => x.registerWidget(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(async (id, handler) => {
if (id) {
return await handler(view);
} else {
Promise.reject();
}
});
return {
apiWrapper: apiWrapper,
view: view,
onClick: onClick,
};
}

View File

@@ -0,0 +1,22 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
const _typeof = {
undefined: 'undefined'
};
/**
* @returns whether the provided parameter is undefined or null.
*/
export function isUndefinedOrNull(obj: any): boolean {
return isUndefined(obj) || obj === null;
}
/**
* @returns whether the provided parameter is undefined.
*/
export function isUndefined(obj: any): boolean {
return typeof (obj) === _typeof.undefined;
}

View File

@@ -0,0 +1,28 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { Account } from 'azdata';
import * as msRest from '@azure/ms-rest-js';
export namespace azureResource {
export interface AzureResource {
name: string;
id: string;
}
export interface AzureResourceSubscription extends AzureResource {
}
export interface AzureResourceResourceGroup extends AzureResource {
}
export interface IAzureResourceService<T extends AzureResource> {
getResources(subscription: AzureResourceSubscription, credential: msRest.ServiceClientCredentials): Promise<T[]>;
}
export type GetSubscriptionsResult = { subscriptions: AzureResourceSubscription[], errors: Error[] };
export type GetResourceGroupsResult = { resourceGroups: AzureResourceResourceGroup[], errors: Error[] };
}

View File

@@ -0,0 +1,106 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
/**
* The API provided by this extension.
*
* @export
*/
export interface IExtensionApi {
getJupyterController(): IJupyterController;
registerPackageManager(providerId: string, packageManagerProvider: IPackageManageProvider): void
getPackageManagers(): Map<string, IPackageManageProvider>
}
export interface IJupyterController {
jupyterInstallation: IJupyterServerInstallation;
}
export interface IJupyterServerInstallation {
installPipPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise<void>;
uninstallPipPackages(packages: IPackageDetails[]): Promise<void>;
installCondaPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise<void>;
uninstallCondaPackages(packages: IPackageDetails[]): Promise<void>;
getInstalledPipPackages(): Promise<IPackageDetails[]>;
pythonExecutable: string;
pythonInstallationPath: string;
executeBufferedCommand(command: string): Promise<string>;
executeStreamedCommand(command: string): Promise<void>;
installPythonPackage(backgroundOperation: azdata.BackgroundOperation, usingExistingPython: boolean, pythonInstallationPath: string, outputChannel: vscode.OutputChannel): Promise<void>;
}
export interface IPackageDetails {
name: string;
version: string;
}
export interface IPackageTarget {
location: string;
packageType: string;
}
export interface IPackageOverview {
name: string;
versions: string[];
summary: string;
}
export interface IPackageLocation {
name: string;
displayName: string;
}
/**
* Package manage provider interface
*/
export interface IPackageManageProvider {
/**
* Provider id
*/
providerId: string;
/**
* package target
*/
packageTarget: IPackageTarget;
/**
* Returns list of installed packages
*/
listPackages(location?: string): Promise<IPackageDetails[]>;
/**
* Installs give packages
* @param package Packages to install
* @param useMinVersion if true, minimal version will be used
*/
installPackages(package: IPackageDetails[], useMinVersion: boolean, location?: string): Promise<void>;
/**
* Uninstalls given packages
* @param package package to uninstall
*/
uninstallPackages(package: IPackageDetails[], location?: string): Promise<void>;
/**
* Returns true if the provider can be used in current context
*/
canUseProvider(): Promise<boolean>;
/**
* Returns location title
*/
getLocations(): Promise<IPackageLocation[]>;
/**
* Returns Package Overview
* @param packageName package name
*/
getPackageOverview(packageName: string): Promise<IPackageOverview>;
}

View File

@@ -0,0 +1,9 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/// <reference path='../../../../src/vs/vscode.d.ts'/>
/// <reference path='../../../../src/sql/azdata.d.ts'/>
/// <reference path='../../../../src/sql/azdata.proposed.d.ts'/>
/// <reference types='@types/node'/>

View File

@@ -0,0 +1,54 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import { ViewBase, LocalPathsEventName } from './viewBase';
import { ApiWrapper } from '../common/apiWrapper';
/**
* Base classes for UI controllers
*/
export abstract class ControllerBase {
/**
* creates new instance
*/
constructor(protected _apiWrapper: ApiWrapper) {
}
/**
* Executes an action and sends back callback event to the view
*/
public async executeAction<T extends ViewBase>(dialog: T, eventName: string, func: (...args: any[]) => Promise<any>, ...args: any[]): Promise<void> {
const callbackEvent = ViewBase.getCallbackEventName(eventName);
try {
let result = await func(...args);
dialog.sendCallbackRequest(callbackEvent, { data: result });
} catch (error) {
dialog.sendCallbackRequest(callbackEvent, { error: error });
}
}
/**
* Register common events for views
* @param view view
*/
public registerEvents(view: ViewBase): void {
view.on(LocalPathsEventName, async (args) => {
await this.executeAction(view, LocalPathsEventName, this.getLocalPaths, this._apiWrapper, args);
});
}
/**
* Returns local file path picked by the user
* @param apiWrapper apiWrapper
*/
public async getLocalPaths(apiWrapper: ApiWrapper, options: vscode.OpenDialogOptions): Promise<string[]> {
let result = await apiWrapper.showOpenDialog(options);
return result ? result?.map(x => x.fsPath) : [];
}
}

View File

@@ -0,0 +1,41 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import { MainViewBase } from './mainViewBase';
import { IPageView } from './interfaces';
/**
* Dialog view to create and manage a dialog
*/
export class DialogView extends MainViewBase {
private _dialog: azdata.window.Dialog | undefined;
/**
* Creates new instance
*/
constructor(apiWrapper: ApiWrapper) {
super(apiWrapper);
}
private createDialogPage(title: string, componentView: IPageView): azdata.window.DialogTab {
let viewPanel = this._apiWrapper.createTab(title);
this.addPage(componentView);
this.registerContent(viewPanel, componentView);
return viewPanel;
}
/**
* Creates a new dialog
* @param title title
* @param pages pages
*/
public createDialog(title: string, pages: IPageView[]): azdata.window.Dialog {
this._dialog = this._apiWrapper.createModelViewDialog(title);
this._dialog.content = pages.map(x => this.createDialogPage(x.title || '', x));
return this._dialog;
}
}

View File

@@ -0,0 +1,89 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import { LanguageViewBase, LanguageUpdateModel } from './languageViewBase';
import { LanguageContentView } from './languageContentView';
import { ApiWrapper } from '../../common/apiWrapper';
export class AddEditLanguageTab extends LanguageViewBase {
private _dialogTab: azdata.window.DialogTab;
public languageName: azdata.TextComponent | undefined;
private _editMode: boolean = false;
public saveButton: azdata.ButtonComponent | undefined;
public languageView: LanguageContentView | undefined;
constructor(
apiWrapper: ApiWrapper,
parent: LanguageViewBase,
private _languageUpdateModel: LanguageUpdateModel) {
super(apiWrapper, parent.root, parent);
this._editMode = !this._languageUpdateModel.newLang;
this._dialogTab = apiWrapper.createTab(constants.extLangNewLanguageTabTitle);
this._dialogTab.registerContent(async view => {
let language = this._languageUpdateModel.language;
let content = this._languageUpdateModel.content;
this.languageName = view.modelBuilder.inputBox().withProperties({
value: language.name,
width: '150px',
enabled: !this._editMode
}).withValidation(component => component.value !== '').component();
let formBuilder = view.modelBuilder.formContainer();
formBuilder.addFormItem({
component: this.languageName,
title: constants.extLangLanguageName,
required: true
});
this.languageView = new LanguageContentView(this._apiWrapper, this, view.modelBuilder, formBuilder, content);
if (!this._editMode) {
this.saveButton = view.modelBuilder.button().withProperties({
label: constants.extLangInstallButtonText,
width: '100px'
}).component();
this.saveButton.onDidClick(async () => {
try {
await this.updateLanguage(this.updatedData);
} catch (err) {
this.showErrorMessage(constants.extLangInstallFailedError, err);
}
});
formBuilder.addFormItem({
component: this.saveButton,
title: ''
});
}
await view.initializeModel(formBuilder.component());
await this.reset();
});
}
public get updatedData(): LanguageUpdateModel {
return {
language: {
name: this.languageName?.value || '',
contents: this._languageUpdateModel.language.contents
},
content: this.languageView?.updatedContent || this._languageUpdateModel.content,
newLang: this._languageUpdateModel.newLang
};
}
public get tab(): azdata.window.DialogTab {
return this._dialogTab;
}
public async reset(): Promise<void> {
if (this.languageName) {
this.languageName.value = this._languageUpdateModel.language.name;
}
this.languageView?.reset();
}
}

View File

@@ -0,0 +1,85 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import { LanguageViewBase } from './languageViewBase';
import { LanguagesTable } from './languagesTable';
import { ApiWrapper } from '../../common/apiWrapper';
export class CurrentLanguagesTab extends LanguageViewBase {
private _installedLangsTab: azdata.window.DialogTab;
private _locationComponent: azdata.TextComponent | undefined;
private _installLanguagesTable: azdata.DeclarativeTableComponent | undefined;
private _languageTable: LanguagesTable | undefined;
private _loader: azdata.LoadingComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: LanguageViewBase) {
super(apiWrapper, parent.root, parent);
this._installedLangsTab = this._apiWrapper.createTab(constants.extLangInstallTabTitle);
this._installedLangsTab.registerContent(async view => {
// TODO: only supporting single location for now. We should add a drop down for multi locations mode
//
let locationTitle = await this.getServerTitle();
this._locationComponent = view.modelBuilder.text().withProperties({
value: locationTitle
}).component();
this._languageTable = new LanguagesTable(apiWrapper, view.modelBuilder, this);
this._installLanguagesTable = this._languageTable.table;
let formModel = view.modelBuilder.formContainer()
.withFormItems([{
component: this._locationComponent,
title: constants.extLangTarget
}, {
component: this._installLanguagesTable,
title: ''
}]).component();
this._loader = view.modelBuilder.loadingComponent()
.withItem(formModel)
.withProperties({
loading: true
}).component();
await view.initializeModel(this._loader);
await this.reset();
});
}
public get tab(): azdata.window.DialogTab {
return this._installedLangsTab;
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
private async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
public async reset(): Promise<void> {
await this.onLoading();
try {
await this._languageTable?.reset();
} catch (err) {
this.showErrorMessage(constants.getErrorMessage(err));
} finally {
await this.onLoaded();
}
}
}

View File

@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../common/constants';
import { ApiWrapper } from '../../common/apiWrapper';
export class FileBrowserDialog {
private _selectedPathTextBox: azdata.InputBoxComponent | undefined;
private _fileBrowserDialog: azdata.window.Dialog | undefined;
private _fileBrowserTree: azdata.FileBrowserTreeComponent | undefined;
private _onPathSelected: vscode.EventEmitter<string> = new vscode.EventEmitter<string>();
public readonly onPathSelected: vscode.Event<string> = this._onPathSelected.event;
constructor(private _apiWrapper: ApiWrapper, private ownerUri: string) {
}
/**
* Opens a dialog to browse server files and folders.
*/
public showDialog(): void {
let fileBrowserTitle = '';
this._fileBrowserDialog = this._apiWrapper.createModelViewDialog(fileBrowserTitle);
let fileBrowserTab = this._apiWrapper.createTab(constants.extLangFileBrowserTabTitle);
this._fileBrowserDialog.content = [fileBrowserTab];
fileBrowserTab.registerContent(async (view) => {
this._fileBrowserTree = view.modelBuilder.fileBrowserTree()
.withProperties({ ownerUri: this.ownerUri, width: 420, height: 700 })
.component();
this._selectedPathTextBox = view.modelBuilder.inputBox()
.withProperties({ inputType: 'text' })
.component();
this._fileBrowserTree.onDidChange((args) => {
if (this._selectedPathTextBox) {
this._selectedPathTextBox.value = args.fullPath;
}
});
let fileBrowserContainer = view.modelBuilder.formContainer()
.withFormItems([{
component: this._fileBrowserTree,
title: ''
}, {
component: this._selectedPathTextBox,
title: constants.extLangSelectedPath
}
]).component();
view.initializeModel(fileBrowserContainer);
});
this._fileBrowserDialog.okButton.onClick(() => {
if (this._selectedPathTextBox) {
let selectedPath = this._selectedPathTextBox.value || '';
this._onPathSelected.fire(selectedPath);
}
});
this._fileBrowserDialog.cancelButton.onClick(() => {
this._onPathSelected.fire('');
});
this._fileBrowserDialog.okButton.label = constants.extLangOkButtonText;
this._fileBrowserDialog.cancelButton.label = constants.extLangCancelButtonText;
this._apiWrapper.openDialog(this._fileBrowserDialog);
}
}

View File

@@ -0,0 +1,156 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as mssql from '../../../../mssql';
import { LanguageViewBase } from './languageViewBase';
import * as constants from '../../common/constants';
import { ApiWrapper } from '../../common/apiWrapper';
export class LanguageContentView extends LanguageViewBase {
private _serverPath: azdata.RadioButtonComponent;
private _localPath: azdata.RadioButtonComponent;
public extensionFile: azdata.TextComponent;
public extensionFileName: azdata.TextComponent;
public envVariables: azdata.TextComponent;
public parameters: azdata.TextComponent;
private _isLocalPath: boolean = true;
/**
*
*/
constructor(
apiWrapper: ApiWrapper,
parent: LanguageViewBase,
private _modelBuilder: azdata.ModelBuilder,
private _formBuilder: azdata.FormBuilder,
private _languageContent: mssql.ExternalLanguageContent | undefined,
) {
super(apiWrapper, parent.root, parent);
this._localPath = this._modelBuilder.radioButton()
.withProperties({
value: 'local',
name: 'extensionLocation',
label: constants.extLangLocal,
checked: true
}).component();
this._serverPath = this._modelBuilder.radioButton()
.withProperties({
value: 'server',
name: 'extensionLocation',
label: this.getServerTitle(),
}).component();
this._localPath.onDidClick(() => {
this._isLocalPath = true;
});
this._serverPath.onDidClick(() => {
this._isLocalPath = false;
});
let flexRadioButtonsModel = this._modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
//width: parent.componentMaxLength
}).withItems([
this._localPath, this._serverPath]
).component();
this.extensionFile = this._modelBuilder.inputBox().withProperties({
value: '',
width: parent.componentMaxLength - parent.browseButtonMaxLength - parent.spaceBetweenComponentsLength
}).component();
let fileBrowser = this._modelBuilder.button().withProperties({
label: '...',
width: parent.browseButtonMaxLength,
CSSStyles: {
'text-align': 'end'
}
}).component();
let flexFilePathModel = this._modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
}).withItems([
this.extensionFile, fileBrowser]
).component();
this.filePathSelected(args => {
this.extensionFile.value = args.filePath;
});
fileBrowser.onDidClick(async () => {
this.onOpenFileBrowser({ filePath: '', target: this._isLocalPath ? constants.localhost : this.connectionUrl });
});
this.extensionFileName = this._modelBuilder.inputBox().withProperties({
value: '',
width: parent.componentMaxLength
}).component();
this.envVariables = this._modelBuilder.inputBox().withProperties({
value: '',
width: parent.componentMaxLength
}).component();
this.parameters = this._modelBuilder.inputBox().withProperties({
value: '',
width: parent.componentMaxLength
}).component();
this.load();
this._formBuilder.addFormItems([{
component: flexRadioButtonsModel,
title: constants.extLangExtensionFileLocation
}, {
component: flexFilePathModel,
title: constants.extLangExtensionFilePath,
required: true
}, {
component: this.extensionFileName,
title: constants.extLangExtensionFileName,
required: true
}, {
component: this.envVariables,
title: constants.extLangEnvVariables
}, {
component: this.parameters,
title: constants.extLangParameters
}]);
}
private load() {
if (this._languageContent) {
this._isLocalPath = this._languageContent.isLocalFile;
this._localPath.checked = this._isLocalPath;
this._serverPath.checked = !this._isLocalPath;
this.extensionFile.value = this._languageContent.pathToExtension;
this.extensionFileName.value = this._languageContent.extensionFileName;
this.envVariables.value = this._languageContent.environmentVariables;
this.parameters.value = this._languageContent.parameters;
}
}
public async reset(): Promise<void> {
this._isLocalPath = true;
this._localPath.checked = this._isLocalPath;
this._serverPath.checked = !this._isLocalPath;
this.load();
}
public get updatedContent(): mssql.ExternalLanguageContent {
return {
pathToExtension: this.extensionFile.value || '',
extensionFileName: this.extensionFileName.value || '',
parameters: this.parameters.value || '',
environmentVariables: this.envVariables.value || '',
isLocalFile: this._isLocalPath || false,
platform: this._languageContent?.platform
};
}
}

View File

@@ -0,0 +1,143 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as mssql from '../../../../mssql';
import { ApiWrapper } from '../../common/apiWrapper';
import { LanguageService } from '../../externalLanguage/languageService';
import { LanguagesDialog } from './languagesDialog';
import { LanguageEditDialog } from './languageEditDialog';
import { FileBrowserDialog } from './fileBrowserDialog';
import { LanguageViewBase, LanguageUpdateModel } from './languageViewBase';
import * as constants from '../../common/constants';
export class LanguageController {
/**
*
*/
constructor(
private _apiWrapper: ApiWrapper,
private _root: string,
private _service: LanguageService) {
}
/**
* Opens the manage language dialog and connects events to the model
*/
public async manageLanguages(): Promise<LanguagesDialog> {
let dialog = new LanguagesDialog(this._apiWrapper, this._root);
// Load current connection
//
await this._service.load();
dialog.connection = this._service.connection;
dialog.connectionUrl = this._service.connectionUrl;
// Handle dialog events and connect to model
//
dialog.onEdit(model => {
this.editLanguage(dialog, model);
});
dialog.onDelete(async deleteModel => {
try {
await this.executeAction(dialog, this.deleteLanguage, this._service, deleteModel);
dialog.onUpdatedLanguage(deleteModel);
} catch (err) {
dialog.onActionFailed(err);
}
});
dialog.onUpdate(async updateModel => {
try {
await this.executeAction(dialog, this.updateLanguage, this._service, updateModel);
dialog.onUpdatedLanguage(updateModel);
} catch (err) {
dialog.onActionFailed(err);
}
});
dialog.onList(async () => {
try {
let result = await this.listLanguages(this._service);
dialog.onListLanguageLoaded(result);
} catch (err) {
dialog.onActionFailed(err);
}
});
this.onSelectFile(dialog);
// Open dialog
//
dialog.showDialog();
return dialog;
}
public async executeAction<T>(dialog: LanguageViewBase, func: (...args: any[]) => Promise<T>, ...args: any[]): Promise<T> {
let result = await func(...args);
await dialog.reset();
return result;
}
public editLanguage(parent: LanguageViewBase, languageUpdateModel: LanguageUpdateModel): void {
let editDialog = new LanguageEditDialog(this._apiWrapper, parent, languageUpdateModel);
editDialog.showDialog();
}
private onSelectFile(dialog: LanguageViewBase): void {
dialog.fileBrowser(async (args) => {
let filePath = '';
if (args.target === constants.localhost) {
filePath = await this.getLocalFilePath();
} else {
filePath = await this.getServerFilePath(args.target);
}
dialog.onFilePathSelected({ filePath: filePath, target: args.target });
});
}
public getServerFilePath(connectionUrl: string): Promise<string> {
return new Promise<string>((resolve) => {
let dialog = new FileBrowserDialog(this._apiWrapper, connectionUrl);
dialog.onPathSelected((selectedPath) => {
resolve(selectedPath);
});
dialog.showDialog();
});
}
public async getLocalFilePath(): Promise<string> {
let result = await this._apiWrapper.showOpenDialog({
canSelectFiles: true,
canSelectFolders: false,
canSelectMany: false
});
return result && result.length > 0 ? result[0].fsPath : '';
}
public async deleteLanguage(model: LanguageService, deleteModel: LanguageUpdateModel): Promise<void> {
await model.deleteLanguage(deleteModel.language.name);
}
public async listLanguages(model: LanguageService): Promise<mssql.ExternalLanguage[]> {
return await model.getLanguageList();
}
public async updateLanguage(model: LanguageService, updateModel: LanguageUpdateModel): Promise<void> {
if (!updateModel.language) {
return;
}
let contents: mssql.ExternalLanguageContent[] = [];
if (updateModel.language.contents && updateModel.language.contents.length >= 0) {
contents = updateModel.language.contents.filter(x => x.platform !== updateModel.content.platform);
}
contents.push(updateModel.content);
updateModel.language.contents = contents;
await model.updateLanguage(updateModel.language);
}
}

View File

@@ -0,0 +1,59 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as constants from '../../common/constants';
import { AddEditLanguageTab } from './addEditLanguageTab';
import { LanguageViewBase, LanguageUpdateModel } from './languageViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
export class LanguageEditDialog extends LanguageViewBase {
public addNewLanguageTab: AddEditLanguageTab | undefined;
constructor(
apiWrapper: ApiWrapper,
parent: LanguageViewBase,
private _languageUpdateModel: LanguageUpdateModel) {
super(apiWrapper, parent.root, parent);
}
/**
* Opens a dialog to edit a language or a content of a language
*/
public showDialog(): void {
this._dialog = this._apiWrapper.createModelViewDialog(constants.extLangDialogTitle);
this.addNewLanguageTab = new AddEditLanguageTab(this._apiWrapper, this, this._languageUpdateModel);
this._dialog.cancelButton.label = constants.extLangCancelButtonText;
this._dialog.okButton.label = constants.extLangSaveButtonText;
this.dialog?.registerCloseValidator(async (): Promise<boolean> => {
return await this.onSave();
});
this._dialog.content = [this.addNewLanguageTab.tab];
this._apiWrapper.openDialog(this._dialog);
}
public async onSave(): Promise<boolean> {
if (this.addNewLanguageTab) {
try {
await this.updateLanguage(this.addNewLanguageTab.updatedData);
return true;
} catch (err) {
this.showErrorMessage(constants.extLangUpdateFailedError, err);
return false;
}
}
return false;
}
/**
* Resets the tabs for given provider Id
*/
public async reset(): Promise<void> {
await this.addNewLanguageTab?.reset();
}
}

View File

@@ -0,0 +1,261 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../common/constants';
import { ApiWrapper } from '../../common/apiWrapper';
import * as mssql from '../../../../mssql';
import * as path from 'path';
export interface LanguageUpdateModel {
language: mssql.ExternalLanguage,
content: mssql.ExternalLanguageContent,
newLang: boolean
}
export interface FileBrowseEventArgs {
filePath: string,
target: string
}
export abstract class LanguageViewBase {
protected _dialog: azdata.window.Dialog | undefined;
public connection: azdata.connection.ConnectionProfile | undefined;
public connectionUrl: string = '';
// Events
//
protected _onEdit: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
public readonly onEdit: vscode.Event<LanguageUpdateModel> = this._onEdit.event;
protected _onUpdate: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
public readonly onUpdate: vscode.Event<LanguageUpdateModel> = this._onUpdate.event;
protected _onDelete: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
public readonly onDelete: vscode.Event<LanguageUpdateModel> = this._onDelete.event;
protected _fileBrowser: vscode.EventEmitter<FileBrowseEventArgs> = new vscode.EventEmitter<FileBrowseEventArgs>();
public readonly fileBrowser: vscode.Event<FileBrowseEventArgs> = this._fileBrowser.event;
protected _filePathSelected: vscode.EventEmitter<FileBrowseEventArgs> = new vscode.EventEmitter<FileBrowseEventArgs>();
public readonly filePathSelected: vscode.Event<FileBrowseEventArgs> = this._filePathSelected.event;
protected _onUpdated: vscode.EventEmitter<LanguageUpdateModel> = new vscode.EventEmitter<LanguageUpdateModel>();
public readonly onUpdated: vscode.Event<LanguageUpdateModel> = this._onUpdated.event;
protected _onList: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onList: vscode.Event<void> = this._onList.event;
protected _onListLoaded: vscode.EventEmitter<mssql.ExternalLanguage[]> = new vscode.EventEmitter<mssql.ExternalLanguage[]>();
public readonly onListLoaded: vscode.Event<mssql.ExternalLanguage[]> = this._onListLoaded.event;
protected _onFailed: vscode.EventEmitter<any> = new vscode.EventEmitter<any>();
public readonly onFailed: vscode.Event<any> = this._onFailed.event;
public componentMaxLength = 350;
public browseButtonMaxLength = 20;
public spaceBetweenComponentsLength = 10;
constructor(protected _apiWrapper: ApiWrapper, protected _root?: string, protected _parent?: LanguageViewBase,) {
if (this._parent) {
if (!this._root) {
this._root = this._parent.root;
}
this.connection = this._parent.connection;
this.connectionUrl = this._parent.connectionUrl;
}
this.registerEvents();
}
private registerEvents() {
if (this._parent) {
this._dialog = this._parent.dialog;
this.fileBrowser(url => {
this._parent?.onOpenFileBrowser(url);
});
this.onUpdate(model => {
this._parent?.onUpdateLanguage(model);
});
this.onEdit(model => {
this._parent?.onEditLanguage(model);
});
this.onDelete(model => {
this._parent?.onDeleteLanguage(model);
});
this.onList(() => {
this._parent?.onListLanguages();
});
this._parent.filePathSelected(x => {
this.onFilePathSelected(x);
});
this._parent.onUpdated(x => {
this.onUpdatedLanguage(x);
});
this._parent.onFailed(x => {
this.onActionFailed(x);
});
this._parent.onListLoaded(x => {
this.onListLanguageLoaded(x);
});
}
}
public async getLocationTitle(): Promise<string> {
let connection = await this.getCurrentConnection();
if (connection) {
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : constants.extLangLocal}`;
}
return constants.noConnectionError;
}
public getServerTitle(): string {
if (this.connection) {
return this.connection.serverName;
}
return constants.noConnectionError;
}
private async getCurrentConnectionUrl(): Promise<string> {
let connection = await this.getCurrentConnection();
if (connection) {
return await this._apiWrapper.getUriForConnection(connection.connectionId);
}
return '';
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
public async loadConnection(): Promise<void> {
this.connection = await this.getCurrentConnection();
this.connectionUrl = await this.getCurrentConnectionUrl();
}
public updateLanguage(updateModel: LanguageUpdateModel): Promise<void> {
return new Promise<void>((resolve, reject) => {
this.onUpdateLanguage(updateModel);
this.onUpdated(() => {
resolve();
});
this.onFailed(err => {
reject(err);
});
});
}
public deleteLanguage(model: LanguageUpdateModel): Promise<void> {
return new Promise<void>((resolve, reject) => {
this.onDeleteLanguage(model);
this.onUpdated(() => {
resolve();
});
this.onFailed(err => {
reject(err);
});
});
}
public listLanguages(): Promise<mssql.ExternalLanguage[]> {
return new Promise<mssql.ExternalLanguage[]>((resolve, reject) => {
this.onListLanguages();
this.onListLoaded(list => {
resolve(list);
});
this.onFailed(err => {
reject(err);
});
});
}
/**
* Dialog model instance
*/
public get dialog(): azdata.window.Dialog | undefined {
return this._dialog;
}
public set dialog(value: azdata.window.Dialog | undefined) {
this._dialog = value;
}
public showInfoMessage(message: string): void {
this.showMessage(message, azdata.window.MessageLevel.Information);
}
public showErrorMessage(message: string, error?: any): void {
this.showMessage(`${message} ${constants.getErrorMessage(error)}`, azdata.window.MessageLevel.Error);
}
public onUpdateLanguage(model: LanguageUpdateModel): void {
this._onUpdate.fire(model);
}
public onUpdatedLanguage(model: LanguageUpdateModel): void {
this._onUpdated.fire(model);
}
public onActionFailed(error: any): void {
this._onFailed.fire(error);
}
public onListLanguageLoaded(list: mssql.ExternalLanguage[]): void {
this._onListLoaded.fire(list);
}
public onEditLanguage(model: LanguageUpdateModel): void {
this._onEdit.fire(model);
}
public onDeleteLanguage(model: LanguageUpdateModel): void {
this._onDelete.fire(model);
}
public onListLanguages(): void {
this._onList.fire();
}
public onOpenFileBrowser(fileBrowseArgs: FileBrowseEventArgs): void {
this._fileBrowser.fire(fileBrowseArgs);
}
public onFilePathSelected(fileBrowseArgs: FileBrowseEventArgs): void {
this._filePathSelected.fire(fileBrowseArgs);
}
private showMessage(message: string, level: azdata.window.MessageLevel): void {
if (this._dialog) {
this._dialog.message = {
text: message,
level: level
};
}
}
public get root(): string {
return this._root || '';
}
public asAbsolutePath(filePath: string): string {
return path.join(this._root || '', filePath);
}
public abstract reset(): Promise<void>;
public createNewContent(): mssql.ExternalLanguageContent {
return {
extensionFileName: '',
isLocalFile: true,
pathToExtension: '',
};
}
public createNewLanguage(): mssql.ExternalLanguage {
return {
name: '',
contents: []
};
}
}

View File

@@ -0,0 +1,56 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { CurrentLanguagesTab } from './currentLanguagesTab';
import { AddEditLanguageTab } from './addEditLanguageTab';
import { LanguageViewBase } from './languageViewBase';
import * as constants from '../../common/constants';
import { ApiWrapper } from '../../common/apiWrapper';
export class LanguagesDialog extends LanguageViewBase {
public currentLanguagesTab: CurrentLanguagesTab | undefined;
public addNewLanguageTab: AddEditLanguageTab | undefined;
constructor(
apiWrapper: ApiWrapper,
root: string) {
super(apiWrapper, root);
}
/**
* Opens a dialog to manage packages used by notebooks.
*/
public showDialog(): void {
this.dialog = this._apiWrapper.createModelViewDialog(constants.extLangDialogTitle);
this.currentLanguagesTab = new CurrentLanguagesTab(this._apiWrapper, this);
let languageUpdateModel = {
language: this.createNewLanguage(),
content: this.createNewContent(),
newLang: true
};
this.addNewLanguageTab = new AddEditLanguageTab(this._apiWrapper, this, languageUpdateModel);
this.dialog.okButton.hidden = true;
this.dialog.cancelButton.label = constants.extLangDoneButtonText;
this.dialog.content = [this.currentLanguagesTab.tab, this.addNewLanguageTab.tab];
this.dialog.registerCloseValidator(() => {
return false; // Blocks Enter key from closing dialog.
});
this._apiWrapper.openDialog(this.dialog);
}
/**
* Resets the tabs for given provider Id
*/
public async reset(): Promise<void> {
await this.currentLanguagesTab?.reset();
await this.addNewLanguageTab?.reset();
}
}

View File

@@ -0,0 +1,165 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../common/constants';
import * as mssql from '../../../../mssql';
import { LanguageViewBase } from './languageViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
export class LanguagesTable extends LanguageViewBase {
private _table: azdata.DeclarativeTableComponent;
/**
*
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: LanguageViewBase) {
super(apiWrapper, parent.root, parent);
this._table = _modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Name
displayName: constants.extLangLanguageName,
ariaLabel: constants.extLangLanguageName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Platform
displayName: constants.extLangLanguagePlatform,
ariaLabel: constants.extLangLanguagePlatform,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created Date
displayName: constants.extLangLanguageCreatedDate,
ariaLabel: constants.extLangLanguageCreatedDate,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
}
public get table(): azdata.DeclarativeTableComponent {
return this._table;
}
public async loadData(): Promise<void> {
let languages: mssql.ExternalLanguage[] | undefined;
languages = await this.listLanguages();
let tableData: any[][] = [];
if (languages) {
languages.forEach(language => {
if (!language.contents || language.contents.length === 0) {
language.contents.push(this.createNewContent());
}
tableData = tableData.concat(language.contents.map(content => this.createTableRow(language, content)));
});
}
this._table.data = tableData;
}
private createTableRow(language: mssql.ExternalLanguage, content: mssql.ExternalLanguageContent): any[] {
if (this._modelBuilder) {
let dropLanguageButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
width: 15,
height: 15
}).component();
dropLanguageButton.onDidClick(async () => {
await this.deleteLanguage({
language: language,
content: content,
newLang: false
});
});
let editLanguageButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
width: 15,
height: 15
}).component();
editLanguageButton.onDidClick(() => {
this.onEditLanguage({
language: language,
content: content,
newLang: false
});
});
return [language.name, content.platform, language.createdDate, dropLanguageButton, editLanguageButton];
}
return [];
}
public async reset(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,43 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { azureResource } from '../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { WorkspaceModel } from '../modelManagement/interfaces';
export interface IDataComponent<T> {
data: T | undefined;
}
export interface IPageView {
registerComponent: (modelBuilder: azdata.ModelBuilder) => azdata.Component;
component: azdata.Component | undefined;
onEnter?: () => Promise<void>;
onLeave?: () => Promise<void>;
validate?: () => Promise<boolean>;
refresh: () => Promise<void>;
disposePage?: () => Promise<void>;
viewPanel: azdata.window.ModelViewPanel | undefined;
title: string;
}
export interface AzureWorkspaceResource {
account?: azdata.Account,
subscription?: azureResource.AzureResourceSubscription,
group?: azureResource.AzureResource,
workspace?: Workspace
}
export interface AzureModelResource extends AzureWorkspaceResource {
model?: WorkspaceModel;
}
export interface IComponentSettings {
multiSelect?: boolean;
editable?: boolean;
selectable?: boolean;
}

View File

@@ -0,0 +1,55 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import { IPageView } from './interfaces';
/**
* Base class for dialog and wizard
*/
export class MainViewBase {
protected _pages: IPageView[] = [];
/**
*
*/
constructor(protected _apiWrapper: ApiWrapper) {
}
protected registerContent(viewPanel: azdata.window.DialogTab | azdata.window.WizardPage, componentView: IPageView) {
viewPanel.registerContent(async view => {
if (componentView) {
let component = componentView.registerComponent(view.modelBuilder);
await view.initializeModel(component);
await componentView.refresh();
}
});
}
protected addPage(page: IPageView, index?: number): void {
if (index) {
this._pages[index] = page;
} else {
this._pages.push(page);
}
}
public async disposePages(): Promise<void> {
if (this._pages) {
await Promise.all(this._pages.map(async (p) => {
if (p.disposePage) {
await p.disposePage();
}
}));
}
}
public async refresh(): Promise<void> {
if (this._pages) {
await Promise.all(this._pages.map(async (p) => await p.refresh()));
}
}
}

View File

@@ -0,0 +1,170 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
import { AzureModelsTable } from './azureModelsTable';
import { IDataComponent, AzureModelResource } from '../interfaces';
import { ModelArtifact } from './prediction/modelArtifact';
import { AzureSignInComponent } from './azureSignInComponent';
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource[]> {
public azureModelsTable: AzureModelsTable | undefined;
public azureFilterComponent: AzureResourceFilterComponent | undefined;
public azureSignInComponent: AzureSignInComponent | undefined;
private _loader: azdata.LoadingComponent | undefined;
private _form: azdata.FormContainer | undefined;
private _downloadedFile: ModelArtifact | undefined;
/**
* Component to render a view to pick an azure model
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this.azureFilterComponent = new AzureResourceFilterComponent(this._apiWrapper, modelBuilder, this);
this.azureModelsTable = new AzureModelsTable(this._apiWrapper, modelBuilder, this, this._multiSelect);
this.azureSignInComponent = new AzureSignInComponent(this._apiWrapper, modelBuilder, this);
this._loader = modelBuilder.loadingComponent()
.withItem(this.azureModelsTable.component)
.withProperties({
loading: true
}).component();
this.azureModelsTable.onModelSelectionChanged(async () => {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
this._downloadedFile = undefined;
});
this.azureFilterComponent.onWorkspacesSelectedChanged(async () => {
await this.onLoading();
await this.azureModelsTable?.loadData(this.azureFilterComponent?.data);
await this.onLoaded();
});
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this.azureFilterComponent.component
}, {
title: '',
component: this._loader
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
this.removeComponents(formBuilder);
if (this.azureFilterComponent?.data?.account) {
this.addAzureComponents(formBuilder);
} else {
this.addAzureSignInComponents(formBuilder);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
this.removeAzureComponents(formBuilder);
this.removeAzureSignInComponents(formBuilder);
}
private addAzureComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.addComponents(formBuilder);
formBuilder.addFormItems([{
title: '',
component: this._loader
}]);
}
}
private removeAzureComponents(formBuilder: azdata.FormBuilder) {
if (this.azureFilterComponent && this._loader) {
this.azureFilterComponent.removeComponents(formBuilder);
formBuilder.removeFormItem({
title: '',
component: this._loader
});
}
}
private addAzureSignInComponents(formBuilder: azdata.FormBuilder) {
if (this.azureSignInComponent) {
this.azureSignInComponent.addComponents(formBuilder);
}
}
private removeAzureSignInComponents(formBuilder: azdata.FormBuilder) {
if (this.azureSignInComponent) {
this.azureSignInComponent.removeComponents(formBuilder);
}
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
private async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Loads the data in the components
*/
public async loadData(): Promise<void> {
await this.azureFilterComponent?.loadData();
}
/**
* Returns selected data
*/
public get data(): AzureModelResource[] | undefined {
return this.azureModelsTable?.data ? this.azureModelsTable?.data.map(x => Object.assign({}, this.azureFilterComponent?.data, {
model: x
})) : undefined;
}
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
const data = this.data;
if (!this._downloadedFile && data && data.length > 0) {
this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(data[0]));
}
return this._downloadedFile;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,184 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../common/constants';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { WorkspaceModel } from '../../modelManagement/interfaces';
import { IDataComponent, AzureWorkspaceResource } from '../interfaces';
/**
* View to render azure models in a table
*/
export class AzureModelsTable extends ModelViewBase implements IDataComponent<WorkspaceModel[]> {
private _table: azdata.DeclarativeTableComponent;
private _selectedModel: WorkspaceModel[] = [];
private _models: WorkspaceModel[] | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
/**
* Creates a view to render azure models in a table
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
this._table = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Version
displayName: constants.modelVersion,
ariaLabel: constants.modelVersion,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
}
public get component(): azdata.DeclarativeTableComponent {
return this._table;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadData(workspaceResource?: AzureWorkspaceResource | undefined): Promise<void> {
if (this._table && workspaceResource) {
this._models = await this.listAzureModels(workspaceResource);
let tableData: any[][] = [];
if (this._models) {
tableData = tableData.concat(this._models.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
this._onModelSelectionChanged.fire();
}
private createTableRow(model: WorkspaceModel): any[] {
if (this._modelBuilder) {
let selectModelButton: azdata.Component;
let onSelectItem = (checked: boolean) => {
const foundItem = this._selectedModel.find(x => x === model);
if (checked && !foundItem) {
this._selectedModel.push(model);
} else if (foundItem) {
this._selectedModel = this._selectedModel.filter(x => x !== model);
}
this._onModelSelectionChanged.fire();
};
if (this._multiSelect) {
const checkbox = this._modelBuilder.checkBox().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
checkbox.onChanged(() => {
onSelectItem(checkbox.checked || false);
});
selectModelButton = checkbox;
} else {
const radioButton = this._modelBuilder.radioButton().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
radioButton.onDidClick(() => {
onSelectItem(radioButton.checked || false);
});
selectModelButton = radioButton;
}
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
}
return [];
}
/**
* Returns selected data
*/
public get data(): WorkspaceModel[] | undefined {
if (this._models && this._selectedModel) {
return this._selectedModel;
}
return undefined;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,207 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import { azureResource } from '../../typings/azure-resource';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import * as constants from '../../common/constants';
import { AzureWorkspaceResource, IDataComponent } from '../interfaces';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 300;
export class AzureResourceFilterComponent extends ModelViewBase implements IDataComponent<AzureWorkspaceResource> {
private _form: azdata.FormContainer;
private _accounts: azdata.DropDownComponent;
private _subscriptions: azdata.DropDownComponent;
private _groups: azdata.DropDownComponent;
private _workspaces: azdata.DropDownComponent;
private _azureAccounts: azdata.Account[] = [];
private _azureSubscriptions: azureResource.AzureResourceSubscription[] = [];
private _azureGroups: azureResource.AzureResource[] = [];
private _azureWorkspaces: Workspace[] = [];
private _onWorkspacesSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onWorkspacesSelectedChanged: vscode.Event<void> = this._onWorkspacesSelectedChanged.event;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._accounts = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._subscriptions = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._groups = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._workspaces = this._modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
this._accounts.onValueChanged(async () => {
await this.onAccountSelected();
});
this._subscriptions.onValueChanged(async () => {
await this.onSubscriptionSelected();
});
this._groups.onValueChanged(async () => {
await this.onGroupSelected();
});
this._workspaces.onValueChanged(async () => {
await this.onWorkspaceSelectedChanged();
});
this._form = this._modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._accounts
}, {
title: constants.azureSubscription,
component: this._subscriptions
}, {
title: constants.azureGroup,
component: this._groups
}, {
title: constants.azureModelWorkspace,
component: this._workspaces
}]).component();
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
formBuilder.addFormItems([{
title: constants.azureAccount,
component: this._accounts
}, {
title: constants.azureSubscription,
component: this._subscriptions
}, {
title: constants.azureGroup,
component: this._groups
}, {
title: constants.azureModelWorkspace,
component: this._workspaces
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._accounts && this._subscriptions && this._groups && this._workspaces) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._accounts
});
formBuilder.removeFormItem({
title: constants.azureSubscription,
component: this._subscriptions
});
formBuilder.removeFormItem({
title: constants.azureGroup,
component: this._groups
});
formBuilder.removeFormItem({
title: constants.azureModelWorkspace,
component: this._workspaces
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component {
return this._form;
}
/**
* Returns selected data
*/
public get data(): AzureWorkspaceResource | undefined {
return {
account: this.account,
subscription: this.subscription,
group: this.group,
workspace: this.workspace
};
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
this._azureAccounts = await this.listAzureAccounts();
if (this._azureAccounts && this._azureAccounts.length > 0) {
let values = this._azureAccounts.map(a => { return { displayName: a.displayInfo.displayName, name: a.key.accountId }; });
this._accounts.values = values;
this._accounts.value = values[0];
}
await this.onAccountSelected();
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
private async onAccountSelected(): Promise<void> {
this._azureSubscriptions = await this.listAzureSubscriptions(this.account);
if (this._azureSubscriptions && this._azureSubscriptions.length > 0) {
let values = this._azureSubscriptions.map(s => { return { displayName: s.name, name: s.id }; });
this._subscriptions.values = values;
this._subscriptions.value = values[0];
}
await this.onSubscriptionSelected();
}
private async onSubscriptionSelected(): Promise<void> {
this._azureGroups = await this.listAzureGroups(this.account, this.subscription);
if (this._azureGroups && this._azureGroups.length > 0) {
let values = this._azureGroups.map(s => { return { displayName: s.name, name: s.id }; });
this._groups.values = values;
this._groups.value = values[0];
}
await this.onGroupSelected();
}
private async onGroupSelected(): Promise<void> {
this._azureWorkspaces = await this.listWorkspaces(this.account, this.subscription, this.group);
if (this._azureWorkspaces && this._azureWorkspaces.length > 0) {
let values = this._azureWorkspaces.map(s => { return { displayName: s.name || '', name: s.id || '' }; });
this._workspaces.values = values;
this._workspaces.value = values[0];
}
this.onWorkspaceSelectedChanged();
}
private onWorkspaceSelectedChanged(): void {
this._onWorkspacesSelectedChanged.fire();
}
private get workspace(): Workspace | undefined {
return this._azureWorkspaces && this._workspaces.value ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
}
private get account(): azdata.Account | undefined {
return this._azureAccounts && this._accounts.value ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
}
private get group(): azureResource.AzureResource | undefined {
return this._azureGroups && this._groups.value ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
}
private get subscription(): azureResource.AzureResourceSubscription | undefined {
return this._azureSubscriptions && this._subscriptions.value ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
}
}

View File

@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SignInToAzureEventName } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 300;
export class AzureSignInComponent extends ModelViewBase {
private _form: azdata.FormContainer;
private _signInButton: azdata.ButtonComponent;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
this._signInButton = this._modelBuilder.button().withProperties({
width: componentWidth,
label: constants.azureSignIn,
}).component();
this._signInButton.onDidClick(() => {
this.sendRequest(SignInToAzureEventName);
});
this._form = this._modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._signInButton
}]).component();
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._signInButton) {
formBuilder.addFormItems([{
title: constants.azureAccount,
component: this._signInButton
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._signInButton) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._signInButton
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component {
return this._form;
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,128 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
/**
* View to pick local models file
*/
export class LocalModelsComponent extends ModelViewBase implements IDataComponent<string[]> {
private _form: azdata.FormContainer | undefined;
private _flex: azdata.FlexContainer | undefined;
private _localPath: azdata.InputBoxComponent | undefined;
private _localBrowse: azdata.ButtonComponent | undefined;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._localPath = modelBuilder.inputBox().withProperties({
value: '',
width: this.componentMaxLength - this.browseButtonMaxLength - this.spaceBetweenComponentsLength
}).component();
this._localBrowse = modelBuilder.button().withProperties({
label: constants.browseModels,
width: this.browseButtonMaxLength,
CSSStyles: {
'text-align': 'end'
}
}).component();
this._localBrowse.onDidClick(async () => {
let options: vscode.OpenDialogOptions = {
canSelectFiles: true,
canSelectFolders: false,
canSelectMany: this._multiSelect,
filters: { 'ONNX File': ['onnx'] }
};
const filePaths = await this.getLocalPaths(options);
if (this._localPath && filePaths && filePaths.length > 0) {
this._localPath.value = this._multiSelect ? filePaths.join(';') : filePaths[0];
} else if (this._localPath) {
this._localPath.value = '';
}
});
this._flex = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between',
width: this.componentMaxLength
}).withItems([
this._localPath, this._localBrowse]
).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._flex
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.addFormItem({
title: '',
component: this._flex
});
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
formBuilder.removeFormItem({
title: '',
component: this._flex
});
}
}
/**
* Returns selected data
*/
public get data(): string[] {
if (this._localPath?.value) {
return this._localPath?.value.split(';');
} else {
return [];
}
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
/**
* Returns the page title
*/
public get title(): string {
return constants.localModelsTitle;
}
}

View File

@@ -0,0 +1,148 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { CurrentModelsTable } from './currentModelsTable';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IPageView, IComponentSettings } from '../../interfaces';
import { TableSelectionComponent } from '../tableSelectionComponent';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* View to render current registered models
*/
export class CurrentModelsComponent extends ModelViewBase implements IPageView {
private _tableComponent: azdata.Component | undefined;
private _dataTable: CurrentModelsTable | undefined;
private _loader: azdata.LoadingComponent | undefined;
private _tableSelectionComponent: TableSelectionComponent | undefined;
/**
*
* @param apiWrapper Creates new view
* @param parent page parent
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
this._tableSelectionComponent.registerComponent(modelBuilder);
this._tableSelectionComponent.onSelectedChanged(async () => {
await this.onTableSelected();
});
this._dataTable = new CurrentModelsTable(this._apiWrapper, this, this._settings);
this._dataTable.registerComponent(modelBuilder);
this._tableComponent = this._dataTable.component;
let formModelBuilder = modelBuilder.formContainer();
this._tableSelectionComponent.addComponents(formModelBuilder);
if (this._tableComponent) {
formModelBuilder.addFormItem({
component: this._tableComponent,
title: ''
});
}
this._loader = modelBuilder.loadingComponent()
.withItem(formModelBuilder.component())
.withProperties({
loading: true
}).component();
return this._loader;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._tableSelectionComponent && this._dataTable) {
this._tableSelectionComponent.addComponents(formBuilder);
this._dataTable.addComponents(formBuilder);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._tableSelectionComponent && this._dataTable) {
this._tableSelectionComponent.removeComponents(formBuilder);
this._dataTable.removeComponents(formBuilder);
}
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._loader;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.onLoading();
try {
if (this._tableSelectionComponent) {
this._tableSelectionComponent.refresh();
}
await this._dataTable?.refresh();
} catch (err) {
this.showErrorMessage(constants.getErrorMessage(err));
} finally {
await this.onLoaded();
}
}
public get data(): ImportedModel[] | undefined {
return this._dataTable?.data;
}
private async onTableSelected(): Promise<void> {
if (this._tableSelectionComponent?.data) {
this.importTable = this._tableSelectionComponent?.data;
await this.storeImportConfigTable();
await this._dataTable?.refresh();
}
}
public get modelTable(): CurrentModelsTable | undefined {
return this._dataTable;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._dataTable) {
await this._dataTable.disposeComponent();
}
}
/**
* returns the title of the page
*/
public get title(): string {
return constants.currentModelsTitle;
}
private async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
private async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
}

View File

@@ -0,0 +1,314 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../../common/constants';
import { ModelViewBase, DeleteModelEventName, EditModelEventName } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { IDataComponent, IComponentSettings } from '../../interfaces';
import { ModelArtifact } from '../prediction/modelArtifact';
import * as utils from '../../../common/utils';
/**
* View to render registered models table
*/
export class CurrentModelsTable extends ModelViewBase implements IDataComponent<ImportedModel[]> {
private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined;
private _selectedModel: ImportedModel[] = [];
private _loader: azdata.LoadingComponent | undefined;
private _downloadedFile: ModelArtifact | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: IComponentSettings) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._modelBuilder = modelBuilder;
let columns = [
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelCreated,
ariaLabel: constants.modelCreated,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
if (this._settings.editable) {
columns.push(
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
);
}
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: columns,
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
this._loader = modelBuilder.loadingComponent()
.withItem(this._table)
.withProperties({
loading: true
}).component();
return this._loader;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this.component) {
formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this.component) {
formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component });
}
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._loader;
}
/**
* Loads the data in the component
*/
public async loadData(): Promise<void> {
await this.onLoading();
if (this._table) {
let models: ImportedModel[] | undefined;
if (this.importTable) {
models = await this.listModels(this.importTable);
} else {
this.showErrorMessage('No import table');
}
let tableData: any[][] = [];
if (models) {
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
}
this._table.data = tableData;
}
this.onModelSelected();
await this.onLoaded();
}
public async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
public async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
private createTableRow(model: ImportedModel): any[] {
let row: any[] = [model.modelName, model.created];
if (this._modelBuilder) {
const selectButton = this.createSelectButton(model);
if (selectButton) {
row.push(selectButton);
}
const editButtons = this.createEditButtons(model);
if (editButtons && editButtons.length > 0) {
row = row.concat(editButtons);
}
}
return row;
}
private createSelectButton(model: ImportedModel): azdata.Component | undefined {
let selectModelButton: azdata.Component | undefined = undefined;
if (this._modelBuilder && this._settings.selectable) {
let onSelectItem = (checked: boolean) => {
if (!this._settings.multiSelect) {
this._selectedModel = [];
}
const foundItem = this._selectedModel.find(x => x === model);
if (checked && !foundItem) {
this._selectedModel.push(model);
} else if (foundItem) {
this._selectedModel = this._selectedModel.filter(x => x !== model);
}
this.onModelSelected();
};
if (this._settings.multiSelect) {
const checkbox = this._modelBuilder.checkBox().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
checkbox.onChanged(() => {
onSelectItem(checkbox.checked || false);
});
selectModelButton = checkbox;
} else {
const radioButton = this._modelBuilder.radioButton().withProperties({
name: 'amlModel',
value: model.id,
width: 15,
height: 15,
checked: false
}).component();
radioButton.onDidClick(() => {
onSelectItem(radioButton.checked || false);
});
selectModelButton = radioButton;
}
}
return selectModelButton;
}
private createEditButtons(model: ImportedModel): azdata.Component[] | undefined {
let dropButton: azdata.ButtonComponent | undefined = undefined;
let editButton: azdata.ButtonComponent | undefined = undefined;
if (this._modelBuilder && this._settings.editable) {
dropButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
width: 15,
height: 15
}).component();
dropButton.onDidClick(async () => {
try {
const confirm = await utils.promptConfirm(constants.confirmDeleteModel(model.modelName), this._apiWrapper);
if (confirm) {
await this.sendDataRequest(DeleteModelEventName, model);
if (this.parent) {
await this.parent?.refresh();
}
}
} catch (error) {
this.showErrorMessage(`${constants.updateModelFailedError} ${constants.getErrorMessage(error)}`);
}
});
editButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
iconPath: {
dark: this.asAbsolutePath('images/dark/edit_inverse.svg'),
light: this.asAbsolutePath('images/light/edit.svg')
},
width: 15,
height: 15
}).component();
editButton.onDidClick(async () => {
await this.sendDataRequest(EditModelEventName, model);
});
}
return editButton && dropButton ? [editButton, dropButton] : undefined;
}
private async onModelSelected(): Promise<void> {
this._onModelSelectionChanged.fire();
if (this._downloadedFile) {
await this._downloadedFile.close();
}
this._downloadedFile = undefined;
}
/**
* Returns selected data
*/
public get data(): ImportedModel[] | undefined {
return this._selectedModel;
}
public async getDownloadedModel(): Promise<ModelArtifact | undefined> {
if (!this._downloadedFile && this.data && this.data.length > 0) {
this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data[0]));
}
return this._downloadedFile;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,75 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { ModelViewBase, UpdateModelEventName } from '../modelViewBase';
import * as constants from '../../../common/constants';
import { ApiWrapper } from '../../../common/apiWrapper';
import { DialogView } from '../../dialogView';
import { ModelDetailsEditPage } from './modelDetailsEditPage';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* Dialog to render registered model views
*/
export class EditModelDialog extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string,
private _parentView: ModelViewBase | undefined,
private _model: ImportedModel) {
super(apiWrapper, root);
this.dialogView = new DialogView(this._apiWrapper);
}
public dialogView: DialogView;
public editModelPage: ModelDetailsEditPage | undefined;
/**
* Opens a dialog to edit models.
*/
public open(): void {
this.editModelPage = new ModelDetailsEditPage(this._apiWrapper, this, this._model);
let registerModelButton = this._apiWrapper.createButton(constants.extLangSaveButtonText);
registerModelButton.onClick(async () => {
if (this.editModelPage) {
const valid = await this.editModelPage.validate();
if (valid) {
try {
await this.sendDataRequest(UpdateModelEventName, this.editModelPage?.data);
this.showInfoMessage(constants.modelUpdatedSuccessfully);
if (this._parentView) {
await this._parentView.refresh();
}
} catch (error) {
this.showInfoMessage(`${constants.modelUpdateFailedError} ${constants.getErrorMessage(error)}`);
}
}
}
});
let dialog = this.dialogView.createDialog(constants.editModelTitle, [this.editModelPage]);
dialog.customButtons = [registerModelButton];
this.mainViewPanel = dialog;
dialog.okButton.hidden = true;
dialog.cancelButton.label = constants.extLangDoneButtonText;
dialog.registerCloseValidator(() => {
return false; // Blocks Enter key from closing dialog.
});
this._apiWrapper.openDialog(dialog);
}
/**
* Resets the tabs for given provider Id
*/
public async refresh(): Promise<void> {
if (this.dialogView) {
this.dialogView.refresh();
}
}
}

View File

@@ -0,0 +1,113 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { ModelSourcesComponent } from '../modelSourcesComponent';
import { LocalModelsComponent } from '../localModelsComponent';
import { AzureModelsComponent } from '../azureModelsComponent';
import * as constants from '../../../common/constants';
import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage';
import { ModelDetailsPage } from '../modelDetailsPage';
import { ModelBrowsePage } from '../modelBrowsePage';
import { ModelImportLocationPage } from './modelImportLocationPage';
/**
* Wizard to register a model
*/
export class ImportModelWizard extends ModelViewBase {
public modelSourcePage: ModelSourcePage | undefined;
public modelBrowsePage: ModelBrowsePage | undefined;
public modelDetailsPage: ModelDetailsPage | undefined;
public modelImportTargetPage: ModelImportLocationPage | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
constructor(
apiWrapper: ApiWrapper,
root: string,
parent?: ModelViewBase) {
super(apiWrapper, root);
this._parentView = parent;
}
/**
* Opens a dialog to manage packages used by notebooks.
*/
public async open(): Promise<void> {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
this.modelBrowsePage = new ModelBrowsePage(this._apiWrapper, this);
this.modelImportTargetPage = new ModelImportLocationPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
let wizard = this.wizardView.createWizard(constants.registerModelTitle, [this.modelImportTargetPage, this.modelSourcePage, this.modelBrowsePage, this.modelDetailsPage]);
this.mainViewPanel = wizard;
wizard.doneButton.label = constants.azureRegisterModel;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
let validated: boolean = true;
if (pageInfo.newPage > pageInfo.lastPage) {
validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
}
if (validated && pageInfo.newPage === undefined) {
wizard.cancelButton.enabled = false;
wizard.backButton.enabled = false;
let result = await this.registerModel();
wizard.cancelButton.enabled = true;
wizard.backButton.enabled = true;
if (this._parentView) {
this._parentView.importTable = this.importTable;
await this._parentView.refresh();
}
return result;
}
return validated;
});
await wizard.open();
}
public get modelResources(): ModelSourcesComponent | undefined {
return this.modelSourcePage?.modelResources;
}
public get localModelsComponent(): LocalModelsComponent | undefined {
return this.modelBrowsePage?.localModelsComponent;
}
public get azureModelsComponent(): AzureModelsComponent | undefined {
return this.modelBrowsePage?.azureModelsComponent;
}
private async registerModel(): Promise<boolean> {
try {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
await this.importLocalModel(this.modelsViewData);
} else {
await this.importAzureModel(this.modelsViewData);
}
await this.storeImportConfigTable();
this.showInfoMessage(constants.modelRegisteredSuccessfully);
return true;
} catch (error) {
this.showErrorMessage(`${constants.modelFailedToRegister} ${constants.getErrorMessage(error)}`);
return false;
}
}
/**
* Refresh the pages
*/
public async refresh(): Promise<void> {
await this.wizardView?.refresh();
}
}

View File

@@ -0,0 +1,63 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { CurrentModelsComponent } from './currentModelsComponent';
import { ModelViewBase, RegisterModelEventName } from '../modelViewBase';
import * as constants from '../../../common/constants';
import { ApiWrapper } from '../../../common/apiWrapper';
import { DialogView } from '../../dialogView';
/**
* Dialog to render registered model views
*/
export class ManageModelsDialog extends ModelViewBase {
constructor(
apiWrapper: ApiWrapper,
root: string) {
super(apiWrapper, root);
this.dialogView = new DialogView(this._apiWrapper);
}
public dialogView: DialogView;
public currentLanguagesTab: CurrentModelsComponent | undefined;
/**
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
this.currentLanguagesTab = new CurrentModelsComponent(this._apiWrapper, this, {
editable: true,
selectable: false
});
let registerModelButton = this._apiWrapper.createButton(constants.importModelTitle);
registerModelButton.onClick(async () => {
await this.sendDataRequest(RegisterModelEventName, this.currentLanguagesTab?.modelTable?.importTable);
});
let dialog = this.dialogView.createDialog(constants.registerModelTitle, [this.currentLanguagesTab]);
dialog.customButtons = [registerModelButton];
this.mainViewPanel = dialog;
dialog.okButton.hidden = true;
dialog.cancelButton.label = constants.extLangDoneButtonText;
dialog.registerCloseValidator(() => {
return false; // Blocks Enter key from closing dialog.
});
this._apiWrapper.openDialog(dialog);
}
/**
* Resets the tabs for given provider Id
*/
public async refresh(): Promise<void> {
if (this.dialogView) {
this.dialogView.refresh();
}
}
}

View File

@@ -0,0 +1,154 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { ImportedModel } from '../../../modelManagement/interfaces';
/**
* View to render filters to pick an azure resource
*/
export class ModelDetailsComponent extends ModelViewBase implements IDataComponent<ImportedModel> {
private _form: azdata.FormContainer | undefined;
private _nameComponent: azdata.InputBoxComponent | undefined;
private _descriptionComponent: azdata.InputBoxComponent | undefined;
private _createdComponent: azdata.Component | undefined;
private _deployedComponent: azdata.Component | undefined;
private _frameworkComponent: azdata.Component | undefined;
private _frameworkVersionComponent: azdata.Component | undefined;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._createdComponent = modelBuilder.text().withProperties({
value: this._model.created
}).component();
this._deployedComponent = modelBuilder.text().withProperties({
value: this._model.deploymentTime
}).component();
this._frameworkComponent = modelBuilder.text().withProperties({
value: this._model.framework
}).component();
this._frameworkVersionComponent = modelBuilder.text().withProperties({
value: this._model.frameworkVersion
}).component();
this._nameComponent = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength,
value: this._model.modelName
}).component();
this._descriptionComponent = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength,
value: this._model.description,
multiline: true,
height: 50
}).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._nameComponent
},
{
title: '',
component: this._descriptionComponent
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
formBuilder.addFormItems([{
title: constants.modelName,
component: this._nameComponent
}, {
title: constants.modelCreated,
component: this._createdComponent
},
{
title: constants.modelDeployed,
component: this._deployedComponent
}, {
title: constants.modelFramework,
component: this._frameworkComponent
}, {
title: constants.modelFrameworkVersion,
component: this._frameworkVersionComponent
}, {
title: constants.modelDescription,
component: this._descriptionComponent
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._nameComponent && this._descriptionComponent && this._createdComponent && this._deployedComponent && this._frameworkComponent && this._frameworkVersionComponent) {
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._createdComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._frameworkComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._frameworkVersionComponent
});
formBuilder.removeFormItem({
title: constants.modelCreated,
component: this._deployedComponent
});
formBuilder.removeFormItem({
title: constants.modelName,
component: this._nameComponent
});
formBuilder.removeFormItem({
title: constants.modelDescription,
component: this._descriptionComponent
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Returns selected data
*/
public get data(): ImportedModel | undefined {
let model = Object.assign({}, this._model);
model.modelName = this._nameComponent?.value || '';
model.description = this._descriptionComponent?.value || '';
return model;
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,85 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { ImportedModel } from '../../../modelManagement/interfaces';
import { ModelDetailsComponent } from './modelDetailsComponent';
/**
* View to pick model source
*/
export class ModelDetailsEditPage extends ModelViewBase implements IPageView, IDataComponent<ImportedModel> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelDetailsComponent: ModelDetailsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _model: ImportedModel) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelDetailsComponent = new ModelDetailsComponent(this._apiWrapper, this, this._model);
this.modelDetailsComponent.registerComponent(modelBuilder);
this.modelDetailsComponent.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ImportedModel | undefined {
return this.modelDetailsComponent?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this.modelDetailsComponent) {
await this.modelDetailsComponent.refresh();
}
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelImportTargetPageTitle;
}
public async disposePage(): Promise<void> {
}
public async validate(): Promise<boolean> {
let validated = false;
if (this.data?.modelName) {
validated = true;
} else {
this.showErrorMessage(constants.modelNameRequiredError);
}
return validated;
}
}

View File

@@ -0,0 +1,97 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { TableSelectionComponent } from '../tableSelectionComponent';
import { DatabaseTable } from '../../../prediction/interfaces';
/**
* View to pick model source
*/
export class ModelImportLocationPage extends ModelViewBase implements IPageView, IDataComponent<DatabaseTable> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public tableSelectionComponent: TableSelectionComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, true);
this.tableSelectionComponent.onSelectedChanged(async () => {
await this.onTableSelected();
});
this.tableSelectionComponent.registerComponent(modelBuilder);
this.tableSelectionComponent.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
private async onTableSelected(): Promise<void> {
if (this.tableSelectionComponent?.data) {
this.importTable = this.tableSelectionComponent?.data;
}
}
/**
* Returns selected data
*/
public get data(): DatabaseTable | undefined {
return this.tableSelectionComponent?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this.tableSelectionComponent) {
await this.tableSelectionComponent.refresh();
}
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelImportTargetPageTitle;
}
public async disposePage(): Promise<void> {
}
public async validate(): Promise<boolean> {
let validated = false;
if (this.data?.databaseName && this.data?.tableName) {
validated = true;
validated = await this.verifyImportConfigTable(this.data);
if (!validated) {
this.showErrorMessage(constants.invalidImportTableSchemaError(this.data?.databaseName, this.data?.tableName));
}
} else {
this.showErrorMessage(constants.invalidImportTableError(this.data?.databaseName, this.data?.tableName));
}
return validated;
}
}

View File

@@ -0,0 +1,207 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType, ModelViewData } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { LocalModelsComponent } from './localModelsComponent';
import { AzureModelsComponent } from './azureModelsComponent';
import * as utils from '../../common/utils';
import { CurrentModelsComponent } from './manageModels/currentModelsComponent';
/**
* View to pick model source
*/
export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
private _form: azdata.FormContainer | undefined;
private _title: string = constants.localModelPageTitle;
private _formBuilder: azdata.FormBuilder | undefined;
public localModelsComponent: LocalModelsComponent | undefined;
public azureModelsComponent: AzureModelsComponent | undefined;
public registeredModelsComponent: CurrentModelsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _multiSelect: boolean = true) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.localModelsComponent = new LocalModelsComponent(this._apiWrapper, this, this._multiSelect);
this.localModelsComponent.registerComponent(modelBuilder);
this.azureModelsComponent = new AzureModelsComponent(this._apiWrapper, this, this._multiSelect);
this.azureModelsComponent.registerComponent(modelBuilder);
this.registeredModelsComponent = new CurrentModelsComponent(this._apiWrapper, this, {
selectable: true,
multiSelect: this._multiSelect,
editable: false
});
this.registeredModelsComponent.registerComponent(modelBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelViewData[] {
return this.modelsViewData;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.modelSourceType === ModelSourceType.Local) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.azureModelsComponent.removeComponents(this._formBuilder);
this.registeredModelsComponent.removeComponents(this._formBuilder);
this.localModelsComponent.addComponents(this._formBuilder);
await this.localModelsComponent.refresh();
}
} else if (this.modelSourceType === ModelSourceType.Azure) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.localModelsComponent.removeComponents(this._formBuilder);
this.azureModelsComponent.addComponents(this._formBuilder);
this.registeredModelsComponent.removeComponents(this._formBuilder);
await this.azureModelsComponent.refresh();
}
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
if (this.localModelsComponent && this.azureModelsComponent && this.registeredModelsComponent) {
this.localModelsComponent.removeComponents(this._formBuilder);
this.azureModelsComponent.removeComponents(this._formBuilder);
this.registeredModelsComponent.addComponents(this._formBuilder);
await this.registeredModelsComponent.refresh();
}
}
}
this.loadTitle();
}
private loadTitle(): void {
if (this.modelSourceType === ModelSourceType.Local) {
this._title = constants.localModelPageTitle;
} else if (this.modelSourceType === ModelSourceType.Azure) {
this._title = constants.azureModelPageTitle;
} else if (this.modelSourceType === ModelSourceType.RegisteredModels) {
this._title = constants.importedModelsPageTitle;
} else {
this._title = constants.modelSourcePageTitle;
}
}
/**
* Returns page title
*/
public get title(): string {
this.loadTitle();
return this._title;
}
public validate(): Promise<boolean> {
let validated = false;
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
validated = this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0;
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
validated = this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0;
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
validated = this.registeredModelsComponent.data !== undefined && this.registeredModelsComponent.data.length > 0;
}
if (!validated) {
this.showErrorMessage(constants.invalidModelToSelectError);
}
return Promise.resolve(validated);
}
public onEnter(): Promise<void> {
return Promise.resolve();
}
public async onLeave(): Promise<void> {
this.modelsViewData = [];
if (this.modelSourceType === ModelSourceType.Local && this.localModelsComponent) {
if (this.localModelsComponent.data !== undefined && this.localModelsComponent.data.length > 0) {
this.modelsViewData = this.localModelsComponent.data.map(x => {
const fileName = utils.getFileName(x);
return {
modelData: x,
modelDetails: {
modelName: fileName,
fileName: fileName
},
targetImportTable: this.importTable
};
});
}
} else if (this.modelSourceType === ModelSourceType.Azure && this.azureModelsComponent) {
if (this.azureModelsComponent.data !== undefined && this.azureModelsComponent.data.length > 0) {
this.modelsViewData = this.azureModelsComponent.data.map(x => {
return {
modelData: {
account: x.account,
subscription: x.subscription,
group: x.group,
workspace: x.workspace,
model: x.model
},
modelDetails: {
modelName: x.model?.name || '',
fileName: x.model?.name,
framework: x.model?.framework,
frameworkVersion: x.model?.frameworkVersion,
created: x.model?.createdTime
},
targetImportTable: this.importTable
};
});
}
} else if (this.modelSourceType === ModelSourceType.RegisteredModels && this.registeredModelsComponent) {
if (this.registeredModelsComponent.data !== undefined) {
this.modelsViewData = this.registeredModelsComponent.data.map(x => {
return {
modelData: x,
modelDetails: {
modelName: ''
},
targetImportTable: this.importTable
};
});
}
}
}
public async disposePage(): Promise<void> {
if (this.azureModelsComponent) {
await this.azureModelsComponent.disposeComponent();
}
if (this.registeredModelsComponent) {
await this.registeredModelsComponent.disposeComponent();
}
}
}

View File

@@ -0,0 +1,83 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelViewData } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { ModelsDetailsTableComponent } from './modelsDetailsTableComponent';
/**
* View to pick model details
*/
export class ModelDetailsPage extends ModelViewBase implements IPageView, IDataComponent<ModelViewData[]> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelDetails: ModelsDetailsTableComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelDetails = new ModelsDetailsTableComponent(this._apiWrapper, modelBuilder, this);
this.modelDetails.registerComponent(modelBuilder);
this.modelDetails.addComponents(this._formBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelViewData[] | undefined {
return this.modelDetails?.data;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this.modelDetails) {
await this.modelDetails.refresh();
}
}
public async onEnter(): Promise<void> {
await this.refresh();
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelDetailsPageTitle;
}
public validate(): Promise<boolean> {
if (this.data && this.data.length > 0 && !this.data.find(x => !x.modelDetails?.modelName)) {
return Promise.resolve(true);
} else {
this.showErrorMessage(constants.modelNameRequiredError);
return Promise.resolve(false);
}
}
}

View File

@@ -0,0 +1,425 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { ImportedModel, WorkspaceModel, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { ManageModelsDialog } from './manageModels/manageModelsDialog';
import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterAzureModelEventName,
ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName, ModelSourceType, ModelViewData, StoreImportTableEventName, VerifyImportTableEventName, EditModelEventName, UpdateModelEventName, DeleteModelEventName, SignInToAzureEventName
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { ImportModelWizard } from './manageModels/importModelWizard';
import * as fs from 'fs';
import * as constants from '../../common/constants';
import { PredictWizard } from './prediction/predictWizard';
import { AzureModelResource } from '../interfaces';
import { PredictService } from '../../prediction/predictService';
import { EditModelDialog } from './manageModels/editModelDialog';
/**
* Model management UI controller
*/
export class ModelManagementController extends ControllerBase {
/**
* Creates new instance
*/
constructor(
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
private _registeredModelService: DeployedModelService,
private _predictService: PredictService) {
super(apiWrapper);
}
/**
* Opens the dialog for model registration
* @param parent parent if the view is opened from another view
* @param controller controller
* @param apiWrapper apiWrapper
* @param root root folder path
*/
public async registerModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
controller = controller || this;
apiWrapper = apiWrapper || this._apiWrapper;
root = root || this._root;
let view = new ImportModelWizard(apiWrapper, root, parent);
if (importTable) {
view.importTable = importTable;
} else {
view.importTable = await controller._registeredModelService.getRecentImportTable();
}
controller.registerEvents(view);
// Open view
//
await view.open();
await view.refresh();
return view;
}
/**
* Opens the dialog to edit model
*/
public async editModel(model: ImportedModel, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise<ModelViewBase> {
controller = controller || this;
apiWrapper = apiWrapper || this._apiWrapper;
root = root || this._root;
let view = new EditModelDialog(apiWrapper, root, parent, model);
controller.registerEvents(view);
// Open view
//
await view.open();
await view.refresh();
return view;
}
/**
* Opens the wizard for prediction
*/
public async predictModel(): Promise<ModelViewBase> {
let view = new PredictWizard(this._apiWrapper, this._root);
view.importTable = await this._registeredModelService.getRecentImportTable();
this.registerEvents(view);
view.on(LoadModelParametersEventName, async () => {
const modelArtifact = await view.getModelFileName();
await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
modelArtifact?.filePath);
});
// Open view
//
await view.open();
await view.refresh();
return view;
}
/**
* Register events in the main view
* @param view main view
*/
public registerEvents(view: ModelViewBase): void {
// Register events
//
super.registerEvents(view);
view.on(ListAccountsEventName, async () => {
await this.executeAction(view, ListAccountsEventName, this.getAzureAccounts, this._amlService);
});
view.on(ListSubscriptionsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListSubscriptionsEventName, this.getAzureSubscriptions, this._amlService, azureArgs.account);
});
view.on(ListWorkspacesEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListWorkspacesEventName, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group);
});
view.on(ListGroupsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListGroupsEventName, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription);
});
view.on(ListAzureModelsEventName, async (arg) => {
let azureArgs = <AzureResourceEventArgs>arg;
await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService
, azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace);
});
view.on(ListModelsEventName, async (args) => {
const table = <DatabaseTable>args;
await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService, table);
});
view.on(RegisterLocalModelEventName, async (arg) => {
let models = <ModelViewData[]>arg;
await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models);
view.refresh();
});
view.on(RegisterModelEventName, async (args) => {
const importTable = <DatabaseTable>args;
await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root);
});
view.on(EditModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, EditModelEventName, this.editModel, model, view, this, this._apiWrapper, this._root);
});
view.on(UpdateModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, UpdateModelEventName, this.updateModel, this._registeredModelService, model);
});
view.on(DeleteModelEventName, async (args) => {
const model = <ImportedModel>args;
await this.executeAction(view, DeleteModelEventName, this.deleteModel, this._registeredModelService, model);
});
view.on(RegisterAzureModelEventName, async (arg) => {
let models = <ModelViewData[]>arg;
await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService,
models);
});
view.on(DownloadAzureModelEventName, async (arg) => {
let registerArgs = <AzureModelResource>arg;
await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService,
registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model);
});
view.on(ListDatabaseNamesEventName, async () => {
await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService);
});
view.on(ListTableNamesEventName, async (arg) => {
let dbName = <string>arg;
await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName);
});
view.on(ListColumnNamesEventName, async (arg) => {
let tableColumnsArgs = <DatabaseTable>arg;
await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService,
tableColumnsArgs);
});
view.on(PredictModelEventName, async (arg) => {
let predictArgs = <PredictModelEventArgs>arg;
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
predictArgs, predictArgs.model, predictArgs.filePath);
});
view.on(DownloadRegisteredModelEventName, async (arg) => {
let model = <ImportedModel>arg;
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
model);
});
view.on(StoreImportTableEventName, async (arg) => {
let importTable = <DatabaseTable>arg;
await this.executeAction(view, StoreImportTableEventName, this.storeImportTable, this._registeredModelService,
importTable);
});
view.on(VerifyImportTableEventName, async (arg) => {
let importTable = <DatabaseTable>arg;
await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService,
importTable);
});
view.on(SourceModelSelectedEventName, async (arg) => {
view.modelSourceType = <ModelSourceType>arg;
await view.refresh();
});
view.on(SignInToAzureEventName, async () => {
await this.executeAction(view, SignInToAzureEventName, this.signInToAzure, this._amlService);
await view.refresh();
});
}
/**
* Opens the dialog for model management
*/
public async manageRegisteredModels(importTable?: DatabaseTable): Promise<ModelViewBase> {
let view = new ManageModelsDialog(this._apiWrapper, this._root);
if (importTable) {
view.importTable = importTable;
} else {
view.importTable = await this._registeredModelService.getRecentImportTable();
}
// Register events
//
this.registerEvents(view);
// Open view
//
view.open();
return view;
}
private async signInToAzure(service: AzureModelRegistryService): Promise<void> {
return await service.signInToAzure();
}
private async getAzureAccounts(service: AzureModelRegistryService): Promise<azdata.Account[]> {
return await service.getAccounts();
}
private async getAzureSubscriptions(service: AzureModelRegistryService, account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[] | undefined> {
return await service.getSubscriptions(account);
}
private async getAzureGroups(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[] | undefined> {
return await service.getGroups(account, subscription);
}
private async getWorkspaces(service: AzureModelRegistryService, account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[] | undefined> {
if (!account || !subscription) {
return [];
}
return await service.getWorkspaces(account, subscription, group);
}
private async getRegisteredModels(registeredModelService: DeployedModelService, table: DatabaseTable): Promise<ImportedModel[]> {
return registeredModelService.getDeployedModels(table);
}
private async getAzureModels(
service: AzureModelRegistryService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined): Promise<WorkspaceModel[]> {
if (!account || !subscription || !resourceGroup || !workspace) {
return [];
}
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: DeployedModelService, models: ModelViewData[] | undefined): Promise<void> {
if (models) {
await Promise.all(models.map(async (model) => {
if (model && model.targetImportTable) {
const localModel = <string>model.modelData;
if (localModel) {
await service.deployLocalModel(localModel, model.modelDetails, model.targetImportTable);
}
} else {
throw Error(constants.invalidModelToRegisterError);
}
}));
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async updateModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
if (model) {
await service.updateModel(model);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async deleteModel(service: DeployedModelService, model: ImportedModel | undefined): Promise<void> {
if (model) {
await service.deleteModel(model);
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
private async registerAzureModel(
azureService: AzureModelRegistryService,
service: DeployedModelService,
models: ModelViewData[] | undefined): Promise<void> {
if (!models) {
throw Error(constants.invalidAzureResourceError);
}
await Promise.all(models.map(async (model) => {
if (model && model.targetImportTable) {
const azureModel = <AzureModelResource>model.modelData;
if (azureModel && azureModel.account && azureModel.subscription && azureModel.group && azureModel.workspace && azureModel.model) {
let filePath: string | undefined;
try {
const filePath = await azureService.downloadModel(azureModel.account, azureModel.subscription, azureModel.group,
azureModel.workspace, azureModel.model);
if (filePath) {
await service.deployLocalModel(filePath, model.modelDetails, model.targetImportTable);
} else {
throw Error(constants.invalidModelToRegisterError);
}
} finally {
if (filePath) {
await fs.promises.unlink(filePath);
}
}
}
} else {
throw Error(constants.invalidModelToRegisterError);
}
}));
}
private async getDatabaseList(predictService: PredictService): Promise<string[]> {
return await predictService.getDatabaseList();
}
private async getTableList(predictService: PredictService, databaseName: string): Promise<DatabaseTable[]> {
return await predictService.getTableList(databaseName);
}
private async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
return await predictService.getTableColumnsList(databaseTable);
}
private async generatePredictScript(
predictService: PredictService,
predictParams: PredictParameters,
registeredModel: ImportedModel | undefined,
filePath: string | undefined
): Promise<string> {
if (!predictParams) {
throw Error(constants.invalidModelToPredictError);
}
const result = await predictService.generatePredictScript(predictParams, registeredModel, filePath);
return result;
}
private async storeImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<void> {
if (table) {
await registeredModelService.storeRecentImportTable(table);
} else {
throw Error(constants.invalidImportTableError(undefined, undefined));
}
}
private async verifyImportTable(registeredModelService: DeployedModelService, table: DatabaseTable | undefined): Promise<boolean> {
if (table) {
return await registeredModelService.verifyConfigTable(table);
} else {
throw Error(constants.invalidImportTableError(undefined, undefined));
}
}
private async downloadRegisteredModel(
registeredModelService: DeployedModelService,
model: ImportedModel | undefined): Promise<string> {
if (!model) {
throw Error(constants.invalidModelToPredictError);
}
return await registeredModelService.downloadModel(model);
}
private async loadModelParameters(
registeredModelService: DeployedModelService,
model: string | undefined): Promise<ModelParameters | undefined> {
if (!model) {
return undefined;
}
return await registeredModelService.loadModelParameters(model);
}
private async downloadAzureModel(
azureService: AzureModelRegistryService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
workspace: Workspace | undefined,
model: WorkspaceModel | undefined): Promise<string> {
if (!account || !subscription || !resourceGroup || !workspace || !model) {
throw Error(constants.invalidAzureResourceError);
}
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
return filePath;
} else {
throw Error(constants.invalidModelToRegisterError);
}
}
}

View File

@@ -0,0 +1,69 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelSourceType } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IPageView, IDataComponent } from '../interfaces';
import { ModelSourcesComponent } from './modelSourcesComponent';
/**
* View to pick model source
*/
export class ModelSourcePage extends ModelViewBase implements IPageView, IDataComponent<ModelSourceType> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public modelResources: ModelSourcesComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.modelResources = new ModelSourcesComponent(this._apiWrapper, this, this._options);
this.modelResources.registerComponent(modelBuilder);
this.modelResources.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): ModelSourceType {
return this.modelResources?.data || ModelSourceType.Local;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
/**
* Returns page title
*/
public get title(): string {
return constants.modelSourcePageTitle;
}
public async disposePage(): Promise<void> {
}
}

View File

@@ -0,0 +1,156 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, SourceModelSelectedEventName, ModelSourceType } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
/**
* View to pick model source
*/
export class ModelSourcesComponent extends ModelViewBase implements IDataComponent<ModelSourceType> {
private _form: azdata.FormContainer | undefined;
private _flexContainer: azdata.FlexContainer | undefined;
private _amlModel: azdata.CardComponent | undefined;
private _localModel: azdata.CardComponent | undefined;
private _registeredModels: azdata.CardComponent | undefined;
private _sourceType: ModelSourceType = ModelSourceType.Local;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _options: ModelSourceType[] = [ModelSourceType.Local, ModelSourceType.Azure]) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._localModel = modelBuilder.card()
.withProperties({
value: 'local',
name: 'modelLocation',
label: constants.localModelSource,
selected: this._options[0] === ModelSourceType.Local,
cardType: azdata.CardType.VerticalButton,
width: 50
}).component();
this._amlModel = modelBuilder.card()
.withProperties({
value: 'aml',
name: 'modelLocation',
label: constants.azureModelSource,
selected: this._options[0] === ModelSourceType.Azure,
cardType: azdata.CardType.VerticalButton,
width: 50
}).component();
this._registeredModels = modelBuilder.card()
.withProperties({
value: 'registered',
name: 'modelLocation',
label: constants.registeredModelsSource,
selected: this._options[0] === ModelSourceType.RegisteredModels,
cardType: azdata.CardType.VerticalButton,
width: 50
}).component();
this._localModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Local;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._amlModel && this._registeredModels) {
this._amlModel.selected = false;
this._registeredModels.selected = false;
}
});
this._amlModel.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.Azure;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._registeredModels) {
this._localModel.selected = false;
this._registeredModels.selected = false;
}
});
this._registeredModels.onCardSelectedChanged(() => {
this._sourceType = ModelSourceType.RegisteredModels;
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
if (this._localModel && this._amlModel) {
this._localModel.selected = false;
this._amlModel.selected = false;
}
});
let components: azdata.Component[] = [];
this._options.forEach(option => {
switch (option) {
case ModelSourceType.Local:
if (this._localModel) {
components.push(this._localModel);
}
break;
case ModelSourceType.Azure:
if (this._amlModel) {
components.push(this._amlModel);
}
break;
case ModelSourceType.RegisteredModels:
if (this._registeredModels) {
components.push(this._registeredModels);
}
break;
}
});
this._sourceType = this._options[0];
this.sendRequest(SourceModelSelectedEventName, this._sourceType);
this._flexContainer = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between'
}).withItems(components).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._flexContainer
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
formBuilder.addFormItem({ title: '', component: this._flexContainer });
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flexContainer) {
formBuilder.removeFormItem({ title: '', component: this._flexContainer });
}
}
/**
* Returns selected data
*/
public get data(): ModelSourceType {
return this._sourceType;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,328 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { ViewBase } from '../viewBase';
import { ImportedModel, WorkspaceModel, ImportedModelDetails, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
export interface AzureResourceEventArgs extends AzureWorkspaceResource {
}
export interface RegisterModelEventArgs extends AzureWorkspaceResource {
details?: ImportedModelDetails
}
export interface PredictModelEventArgs extends PredictParameters {
model?: ImportedModel;
filePath?: string;
}
export enum ModelSourceType {
Local,
Azure,
RegisteredModels
}
export interface ModelViewData {
modelFile?: string;
modelData: AzureModelResource | string | ImportedModel;
modelDetails?: ImportedModelDetails;
targetImportTable?: DatabaseTable;
}
// Event names
//
export const ListModelsEventName = 'listModels';
export const ListAzureModelsEventName = 'listAzureModels';
export const ListAccountsEventName = 'listAccounts';
export const ListDatabaseNamesEventName = 'listDatabaseNames';
export const ListTableNamesEventName = 'listTableNames';
export const ListColumnNamesEventName = 'listColumnNames';
export const ListSubscriptionsEventName = 'listSubscriptions';
export const ListGroupsEventName = 'listGroups';
export const ListWorkspacesEventName = 'listWorkspaces';
export const RegisterLocalModelEventName = 'registerLocalModel';
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
export const PredictModelEventName = 'predictModel';
export const RegisterModelEventName = 'registerModel';
export const EditModelEventName = 'editModel';
export const UpdateModelEventName = 'updateModel';
export const DeleteModelEventName = 'deleteModel';
export const SourceModelSelectedEventName = 'sourceModelSelected';
export const LoadModelParametersEventName = 'loadModelParameters';
export const StoreImportTableEventName = 'storeImportTable';
export const VerifyImportTableEventName = 'verifyImportTable';
export const SignInToAzureEventName = 'signInToAzure';
/**
* Base class for all model management views
*/
export abstract class ModelViewBase extends ViewBase {
private _modelSourceType: ModelSourceType = ModelSourceType.Local;
private _modelsViewData: ModelViewData[] = [];
private _importTable: DatabaseTable | undefined;
constructor(apiWrapper: ApiWrapper, root?: string, parent?: ModelViewBase) {
super(apiWrapper, root, parent);
}
protected getEventNames(): string[] {
return super.getEventNames().concat([ListModelsEventName,
ListAzureModelsEventName,
ListAccountsEventName,
ListSubscriptionsEventName,
ListGroupsEventName,
ListWorkspacesEventName,
RegisterLocalModelEventName,
RegisterAzureModelEventName,
RegisterModelEventName,
SourceModelSelectedEventName,
ListDatabaseNamesEventName,
ListTableNamesEventName,
ListColumnNamesEventName,
PredictModelEventName,
DownloadAzureModelEventName,
DownloadRegisteredModelEventName,
LoadModelParametersEventName,
StoreImportTableEventName,
VerifyImportTableEventName,
EditModelEventName,
UpdateModelEventName,
DeleteModelEventName,
SignInToAzureEventName]);
}
/**
* Parent view
*/
public get parent(): ModelViewBase | undefined {
return this._parent ? <ModelViewBase>this._parent : undefined;
}
/**
* list azure models
*/
public async listAzureModels(workspaceResource: AzureWorkspaceResource): Promise<WorkspaceModel[]> {
const args: AzureResourceEventArgs = workspaceResource;
return await this.sendDataRequest(ListAzureModelsEventName, args);
}
/**
* list registered models
*/
public async listModels(table: DatabaseTable): Promise<ImportedModel[]> {
return await this.sendDataRequest(ListModelsEventName, table);
}
/**
* lists azure accounts
*/
public async listAzureAccounts(): Promise<azdata.Account[]> {
return await this.sendDataRequest(ListAccountsEventName);
}
/**
* lists database names
*/
public async listDatabaseNames(): Promise<string[]> {
return await this.sendDataRequest(ListDatabaseNamesEventName);
}
/**
* lists table names
*/
public async listTableNames(dbName: string): Promise<DatabaseTable[]> {
return await this.sendDataRequest(ListTableNamesEventName, dbName);
}
/**
* lists column names
*/
public async listColumnNames(table: DatabaseTable): Promise<TableColumn[]> {
return await this.sendDataRequest(ListColumnNamesEventName, table);
}
/**
* lists azure subscriptions
* @param account azure account
*/
public async listAzureSubscriptions(account: azdata.Account | undefined): Promise<azureResource.AzureResourceSubscription[]> {
const args: AzureResourceEventArgs = {
account: account
};
return await this.sendDataRequest(ListSubscriptionsEventName, args);
}
/**
* registers local model
* @param localFilePath local file path
*/
public async importLocalModel(models: ModelViewData[]): Promise<void> {
return await this.sendDataRequest(RegisterLocalModelEventName, models);
}
/**
* downloads registered model
* @param model model to download
*/
public async downloadRegisteredModel(model: ImportedModel | undefined): Promise<string> {
return await this.sendDataRequest(DownloadRegisteredModelEventName, model);
}
/**
* download azure model
* @param args azure resource
*/
public async downloadAzureModel(resource: AzureModelResource | undefined): Promise<string> {
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
}
/**
* Loads model parameters
*/
public async loadModelParameters(): Promise<ModelParameters | undefined> {
return await this.sendDataRequest(LoadModelParametersEventName);
}
/**
* registers azure model
* @param args azure resource
*/
public async importAzureModel(models: ModelViewData[]): Promise<void> {
return await this.sendDataRequest(RegisterAzureModelEventName, models);
}
/**
* Stores the name of the table as recent config table for importing models
*/
public async storeImportConfigTable(): Promise<void> {
await this.sendRequest(StoreImportTableEventName, this.importTable);
}
/**
* Verifies if table is valid to import models to
*/
public async verifyImportConfigTable(table: DatabaseTable): Promise<boolean> {
return await this.sendDataRequest(VerifyImportTableEventName, table);
}
/**
* registers azure model
* @param args azure resource
*/
public async generatePredictScript(model: ImportedModel | undefined, filePath: string | undefined, params: PredictParameters | undefined): Promise<void> {
const args: PredictModelEventArgs = Object.assign({}, params, {
model: model,
filePath: filePath,
loadFromRegisteredModel: !filePath
});
return await this.sendDataRequest(PredictModelEventName, args);
}
/**
* list resource groups
* @param account azure account
* @param subscription azure subscription
*/
public async listAzureGroups(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined): Promise<azureResource.AzureResource[]> {
const args: AzureResourceEventArgs = {
account: account,
subscription: subscription
};
return await this.sendDataRequest(ListGroupsEventName, args);
}
/**
* Sets model source type
*/
public set modelSourceType(value: ModelSourceType) {
if (this.parent) {
this.parent.modelSourceType = value;
} else {
this._modelSourceType = value;
}
}
/**
* Returns model source type
*/
public get modelSourceType(): ModelSourceType {
if (this.parent) {
return this.parent.modelSourceType;
} else {
return this._modelSourceType;
}
}
/**
* Sets model data
*/
public set modelsViewData(value: ModelViewData[]) {
if (this.parent) {
this.parent.modelsViewData = value;
} else {
this._modelsViewData = value;
}
}
/**
* Returns model data
*/
public get modelsViewData(): ModelViewData[] {
if (this.parent) {
return this.parent.modelsViewData;
} else {
return this._modelsViewData;
}
}
/**
* Sets import table
*/
public set importTable(value: DatabaseTable | undefined) {
if (this.parent) {
this.parent.importTable = value;
} else {
this._importTable = value;
}
}
/**
* Returns import table
*/
public get importTable(): DatabaseTable | undefined {
if (this.parent) {
return this.parent.importTable;
} else {
return this._importTable;
}
}
/**
* lists azure workspaces
* @param account azure account
* @param subscription azure subscription
* @param group azure resource group
*/
public async listWorkspaces(account: azdata.Account | undefined, subscription: azureResource.AzureResourceSubscription | undefined, group: azureResource.AzureResource | undefined): Promise<Workspace[]> {
const args: AzureResourceEventArgs = {
account: account,
subscription: subscription,
group: group
};
return await this.sendDataRequest(ListWorkspacesEventName, args);
}
}

View File

@@ -0,0 +1,188 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase, ModelViewData } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
import * as constants from '../../common/constants';
import { IDataComponent } from '../interfaces';
/**
* View to pick local models file
*/
export class ModelsDetailsTableComponent extends ModelViewBase implements IDataComponent<ModelViewData[]> {
private _table: azdata.DeclarativeTableComponent | undefined;
/**
* Creates new view
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Name
displayName: constants.modelFileName,
ariaLabel: constants.modelFileName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.modelName,
ariaLabel: constants.modelName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 150,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Created
displayName: constants.modelDescription,
ariaLabel: constants.modelDescription,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 100,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._table) {
formBuilder.addFormItems([{
title: '',
component: this._table
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._table) {
formBuilder.removeFormItem({
title: '',
component: this._table
});
}
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadData(): Promise<void> {
const models = this.modelsViewData;
if (this._table && models) {
let tableData: any[][] = [];
tableData = tableData.concat(models.map(model => this.createTableRow(model)));
this._table.data = tableData;
}
}
private createTableRow(model: ModelViewData | undefined): any[] {
if (this._modelBuilder && model && model.modelDetails) {
const nameComponent = this._modelBuilder.inputBox().withProperties({
value: model.modelDetails.modelName,
width: this.componentMaxLength - 100,
required: true
}).component();
const descriptionComponent = this._modelBuilder.inputBox().withProperties({
value: model.modelDetails.description,
width: this.componentMaxLength
}).component();
descriptionComponent.onTextChanged(() => {
if (model.modelDetails) {
model.modelDetails.description = descriptionComponent.value;
}
});
nameComponent.onTextChanged(() => {
if (model.modelDetails) {
model.modelDetails.modelName = nameComponent.value || '';
}
});
let deleteButton = this._modelBuilder.button().withProperties({
label: '',
title: constants.deleteTitle,
width: 15,
height: 15,
iconPath: {
dark: this.asAbsolutePath('images/dark/delete_inverse.svg'),
light: this.asAbsolutePath('images/light/delete.svg')
},
}).component();
deleteButton.onDidClick(async () => {
this.modelsViewData = this.modelsViewData.filter(x => x !== model);
await this.refresh();
});
return [model.modelDetails.fileName, nameComponent, descriptionComponent, deleteButton];
}
return [];
}
/**
* Returns selected data
*/
public get data(): ModelViewData[] {
return this.modelsViewData;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._table;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
}

View File

@@ -0,0 +1,101 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { InputColumnsComponent } from './inputColumnsComponent';
import { OutputColumnsComponent } from './outputColumnsComponent';
import { PredictParameters } from '../../../prediction/interfaces';
/**
* View to pick model source
*/
export class ColumnsSelectionPage extends ModelViewBase implements IPageView, IDataComponent<PredictParameters> {
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public inputColumnsComponent: InputColumnsComponent | undefined;
public outputColumnsComponent: OutputColumnsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
*
* @param modelBuilder Register components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.inputColumnsComponent = new InputColumnsComponent(this._apiWrapper, this);
this.inputColumnsComponent.registerComponent(modelBuilder);
this.inputColumnsComponent.addComponents(this._formBuilder);
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
this.outputColumnsComponent.registerComponent(modelBuilder);
this.outputColumnsComponent.addComponents(this._formBuilder);
this._form = this._formBuilder.component();
return this._form;
}
/**
* Returns selected data
*/
public get data(): PredictParameters | undefined {
return this.inputColumnsComponent?.data && this.outputColumnsComponent?.data ?
Object.assign({}, this.inputColumnsComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
undefined;
}
/**
* Returns the component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.inputColumnsComponent) {
await this.inputColumnsComponent.refresh();
}
if (this.outputColumnsComponent) {
await this.outputColumnsComponent.refresh();
}
}
}
public async onEnter(): Promise<void> {
await this.inputColumnsComponent?.onLoading();
await this.outputColumnsComponent?.onLoading();
try {
const modelParameters = await this.loadModelParameters();
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
this.inputColumnsComponent.modelParameters = modelParameters;
this.outputColumnsComponent.modelParameters = modelParameters;
await this.inputColumnsComponent.refresh();
await this.outputColumnsComponent.refresh();
}
} catch (error) {
this.showErrorMessage(constants.loadModelParameterFailedError, error);
}
await this.inputColumnsComponent?.onLoaded();
await this.outputColumnsComponent?.onLoaded();
}
/**
* Returns page title
*/
public get title(): string {
return constants.columnSelectionPageTitle;
}
}

View File

@@ -0,0 +1,302 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IDataComponent } from '../../interfaces';
import { PredictColumn, DatabaseTable, TableColumn } from '../../../prediction/interfaces';
import { ModelParameter, ModelParameters } from '../../../modelManagement/interfaces';
/**
* View to render azure models in a table
*/
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
private _table: azdata.DeclarativeTableComponent | undefined;
private _parameters: PredictColumn[] = [];
private _loader: azdata.LoadingComponent;
private _dataTypes: string[] = [
'bigint',
'int',
'smallint',
'real',
'float',
'varchar(MAX)',
'bit'
];
/**
* Creates a view to render azure models in a table
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _forInput: boolean = true) {
super(apiWrapper, parent.root, parent);
this._loader = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.LoadingComponent {
let columnHeader: azdata.DeclarativeTableColumn[];
if (this._forInput) {
columnHeader = [
{ // Action
displayName: constants.columnName,
ariaLabel: constants.columnName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: '',
ariaLabel: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.inputName,
ariaLabel: constants.inputName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 120,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
} else {
columnHeader = [
{ // Name
displayName: constants.outputName,
ariaLabel: constants.outputName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 200,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.displayName,
ariaLabel: constants.displayName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.dataTypeName,
ariaLabel: constants.dataTypeName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
}
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: columnHeader,
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
this._loader = modelBuilder.loadingComponent()
.withItem(this._table)
.withProperties({
loading: true
}).component();
return this._loader;
}
public async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
public async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
public get component(): azdata.Component {
return this._loader;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadInputs(modelParameters: ModelParameters | undefined, table: DatabaseTable): Promise<void> {
await this.onLoading();
this._parameters = [];
let tableData: any[][] = [];
if (this._table) {
if (this._forInput) {
const columns = await this.listColumnNames(table);
if (modelParameters?.inputs && columns) {
tableData = tableData.concat(modelParameters.inputs.map(input => this.createInputTableRow(input, columns)));
}
}
this._table.data = tableData;
}
await this.onLoaded();
}
public async loadOutputs(modelParameters: ModelParameters | undefined): Promise<void> {
this.onLoading();
this._parameters = [];
let tableData: any[][] = [];
if (this._table) {
if (!this._forInput) {
if (modelParameters?.outputs && this._dataTypes) {
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
}
}
this._table.data = tableData;
}
this.onLoaded();
}
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
if (this._modelBuilder) {
let nameInput = this._modelBuilder.dropDown().withProperties({
values: dataTypes,
width: this.componentMaxLength
}).component();
const name = modelParameter.name;
const dataType = dataTypes.find(x => x === modelParameter.type);
if (dataType) {
nameInput.value = dataType;
}
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
nameInput.onValueChanged(() => {
const value = <string>nameInput.value;
if (value !== modelParameter.type) {
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.dataType = value;
}
}
});
let displayNameInput = this._modelBuilder.inputBox().withProperties({
value: name,
width: 200
}).component();
displayNameInput.onTextChanged(() => {
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.columnName = displayNameInput.value || name;
}
});
return [`${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
}
return [];
}
private createInputTableRow(modelParameter: ModelParameter, columns: TableColumn[] | undefined): any[] {
if (this._modelBuilder && columns) {
const values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; });
let nameInput = this._modelBuilder.dropDown().withProperties({
values: values,
width: this.componentMaxLength
}).component();
const name = modelParameter.name;
let column = values.find(x => x.name === modelParameter.name);
if (!column) {
column = values[0];
}
nameInput.value = column;
this._parameters.push({ columnName: column.name, paramName: name });
nameInput.onValueChanged(() => {
const selectedColumn = nameInput.value;
const value = selectedColumn ? (<azdata.CategoryValue>selectedColumn).name : undefined;
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.columnName = value || '';
}
});
const label = this._modelBuilder.inputBox().withProperties({
value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`,
enabled: false,
width: this.componentMaxLength
}).component();
const image = this._modelBuilder.image().withProperties({
width: 50,
height: 50,
iconPath: {
dark: this.asAbsolutePath('images/arrow.svg'),
light: this.asAbsolutePath('images/arrow.svg')
},
iconWidth: 20,
iconHeight: 20,
title: 'maps'
}).component();
return [nameInput, image, label];
}
return [];
}
/**
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
return this._parameters;
}
/**
* Refreshes the view
*/
public async refresh(): Promise<void> {
}
}

View File

@@ -0,0 +1,142 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
import { ModelParameters } from '../../../modelManagement/interfaces';
import { ColumnsTable } from './columnsTable';
import { TableSelectionComponent } from '../tableSelectionComponent';
/**
* View to render filters to pick an azure resource
*/
export class InputColumnsComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
private _form: azdata.FormContainer | undefined;
private _tableSelectionComponent: TableSelectionComponent | undefined;
private _columns: ColumnsTable | undefined;
private _modelParameters: ModelParameters | undefined;
/**
* Creates a new view
*/
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
super(apiWrapper, parent.root, parent);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false);
this._tableSelectionComponent.registerComponent(modelBuilder);
this._tableSelectionComponent.onSelectedChanged(async () => {
await this.onTableSelected();
});
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this);
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.inputColumns,
component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
formBuilder.addFormItems([{
title: '',
component: this._tableSelectionComponent.component
}, {
title: constants.inputColumns,
component: this._columns.component
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._columns && this._tableSelectionComponent && this._tableSelectionComponent.component) {
formBuilder.removeFormItem({
title: '',
component: this._tableSelectionComponent.component
});
formBuilder.removeFormItem({
title: constants.inputColumns,
component: this._columns.component
});
}
}
/**
* Returns the created component
*/
public get component(): azdata.Component | undefined {
return this._form;
}
/**
* Returns selected data
*/
public get data(): PredictInputParameters | undefined {
return Object.assign({}, this.databaseTable, {
inputColumns: this.columnNames
});
}
/**
* loads data in the components
*/
public async loadData(): Promise<void> {
if (this._tableSelectionComponent) {
this._tableSelectionComponent.refresh();
}
}
public set modelParameters(value: ModelParameters) {
this._modelParameters = value;
}
public async onLoading(): Promise<void> {
if (this._columns) {
await this._columns.onLoading();
}
}
public async onLoaded(): Promise<void> {
if (this._columns) {
await this._columns.onLoaded();
}
}
/**
* refreshes the view
*/
public async refresh(): Promise<void> {
await this.loadData();
}
private async onTableSelected(): Promise<void> {
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
}
private get databaseTable(): DatabaseTable {
let selectedItem = this._tableSelectionComponent?.data;
return {
databaseName: selectedItem?.databaseName,
tableName: selectedItem?.tableName,
schema: selectedItem?.schema
};
}
private get columnNames(): PredictColumn[] | undefined {
return this._columns?.data;
}
}

View File

@@ -0,0 +1,35 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as utils from '../../../common/utils';
/**
* Wizard to register a model
*/
export class ModelArtifact {
/**
* Creates new model artifact
*/
constructor(private _filePath: string, private _deleteAtClose: boolean = true) {
}
public get filePath(): string {
return this._filePath;
}
/**
* Closes the artifact and disposes the resources
*/
public async close(): Promise<void> {
if (this._deleteAtClose) {
try {
await utils.deleteFile(this._filePath);
} catch {
}
}
}
}

Some files were not shown because too many files have changed in this diff Show More