diff --git a/extensions/machine-learning-services/config.json b/extensions/machine-learning-services/config.json index 1ba95ab82e..ab363773d5 100644 --- a/extensions/machine-learning-services/config.json +++ b/extensions/machine-learning-services/config.json @@ -1,13 +1,9 @@ { "sqlPackageManagement": { "requiredPythonPackages": [ - { - "name": "pymssql", - "version": "2.1.4" - }, { "name": "sqlmlutils", - "version": "" + "version": "1.0.0" } ], "requiredRPackages": [ diff --git a/extensions/machine-learning-services/package.json b/extensions/machine-learning-services/package.json index 0ac1dc3946..abe3a7f0a8 100644 --- a/extensions/machine-learning-services/package.json +++ b/extensions/machine-learning-services/package.json @@ -37,7 +37,7 @@ }, "machineLearningServices.enableR": { "type": "boolean", - "default": "true", + "default": "false", "description": "%mls.enableR.description%" }, "machineLearningServices.pythonPath": { diff --git a/extensions/machine-learning-services/src/common/constants.ts b/extensions/machine-learning-services/src/common/constants.ts index 6711463a77..565a111475 100644 --- a/extensions/machine-learning-services/src/common/constants.ts +++ b/extensions/machine-learning-services/src/common/constants.ts @@ -47,7 +47,8 @@ export const msgYes = localize('msgYes', "Yes"); export const msgNo = localize('msgNo', "No"); export const managePackageCommandError = localize('mls.managePackages.error', "Either no connection is available or the server does not have external script enabled."); export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); } -export const installDependenciesMsgTaskName = localize('mls.installDependencies.msgTaskName', "Installing Machine Learning extension dependencies"); +export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Installing package management dependencies"); +export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Installing model management dependencies"); export const noResultError = localize('mls.noResultError', "No Result returned"); export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotInstalled', "The required dependencies are not installed"); export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that."); diff --git a/extensions/machine-learning-services/src/common/queryRunner.ts b/extensions/machine-learning-services/src/common/queryRunner.ts index 90f92a8739..ec46cee7d7 100644 --- a/extensions/machine-learning-services/src/common/queryRunner.ts +++ b/extensions/machine-learning-services/src/common/queryRunner.ts @@ -7,22 +7,32 @@ import * as azdata from 'azdata'; import * as nbExtensionApis from '../typings/notebookServices'; import { ApiWrapper } from './apiWrapper'; import * as constants from '../common/constants'; +import * as utils from '../common/utils'; -const maxNumberOfRetries = 3; +const maxNumberOfRetries = 2; const listPythonPackagesQuery = ` +Declare @tablevar table(name NVARCHAR(MAX), version NVARCHAR(MAX)) +insert into @tablevar(name, version) EXEC sp_execute_external_script @language=N'Python', @script=N'import pkg_resources import pandas OutputDataSet = pandas.DataFrame([(d.project_name, d.version) for d in pkg_resources.working_set])' +select e.name, version from sys.external_libraries e join @tablevar t on e.name = t.name +where [language] = 'PYTHON' `; const listRPackagesQuery = ` +Declare @tablevar table(name NVARCHAR(MAX), version NVARCHAR(MAX)) +insert into @tablevar(name, version) EXEC sp_execute_external_script @language=N'R', @script=N' OutputDataSet <- as.data.frame(installed.packages()[,c(1,3)])' + +select e.name, version from sys.external_libraries e join @tablevar t on e.name = t.name +where [language] = 'R' `; const checkMlInstalledQuery = ` @@ -63,24 +73,24 @@ export class QueryRunner { * Returns python packages installed in SQL server instance * @param connection SQL Connection */ - public async getPythonPackages(connection: azdata.connection.ConnectionProfile): Promise { - return this.getPackages(connection, listPythonPackagesQuery); + public async getPythonPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise { + return this.getPackages(connection, databaseName, listPythonPackagesQuery); } /** * Returns python packages installed in SQL server instance * @param connection SQL Connection */ - public async getRPackages(connection: azdata.connection.ConnectionProfile): Promise { - return this.getPackages(connection, listRPackagesQuery); + public async getRPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise { + return this.getPackages(connection, databaseName, listRPackagesQuery); } - private async getPackages(connection: azdata.connection.ConnectionProfile, script: string): Promise { + private async getPackages(connection: azdata.connection.ConnectionProfile, databaseName: string, script: string): Promise { let packages: nbExtensionApis.IPackageDetails[] = []; let result: azdata.SimpleExecuteResult | undefined = undefined; for (let index = 0; index < maxNumberOfRetries; index++) { - result = await this.runQuery(connection, script); + result = await this.runQuery(connection, utils.getScriptWithDBChange(connection.databaseName, databaseName, script)); if (result && result.rowCount > 0) { break; } diff --git a/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts b/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts index 1b5022b554..1efe1780cb 100644 --- a/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts +++ b/extensions/machine-learning-services/src/modelManagement/modelPythonClient.ts @@ -40,7 +40,7 @@ export class ModelPythonClient { * Installs dependencies for python client */ private async installDependencies(): Promise { - await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [ + await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [ this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true); } diff --git a/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts b/extensions/machine-learning-services/src/packageManagement/packageManageProviderBase.ts similarity index 76% rename from extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts rename to extensions/machine-learning-services/src/packageManagement/packageManageProviderBase.ts index 90870aa84a..36549d9523 100644 --- a/extensions/machine-learning-services/src/packageManagement/SqlPackageManageProviderBase.ts +++ b/extensions/machine-learning-services/src/packageManagement/packageManageProviderBase.ts @@ -5,7 +5,6 @@ import * as azdata from 'azdata'; import { ApiWrapper } from '../common/apiWrapper'; -import * as constants from '../common/constants'; import * as nbExtensionApis from '../typings/notebookServices'; import * as utils from '../common/utils'; @@ -23,14 +22,17 @@ export abstract class SqlPackageManageProviderBase { } /** - * Returns location title + * Returns database names */ - public async getLocationTitle(): Promise { + public async getLocations(): Promise { let connection = await this.getCurrentConnection(); if (connection) { - return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`; + let databases = await this._apiWrapper.listDatabases(connection.connectionId); + return databases.map(x => { + return { displayName: x, name: x }; + }); } - return constants.noConnectionError; + return []; } protected async getCurrentConnection(): Promise { @@ -42,16 +44,16 @@ export abstract class SqlPackageManageProviderBase { * @param packages Packages to install * @param useMinVersion minimum version */ - public async installPackages(packages: nbExtensionApis.IPackageDetails[], useMinVersion: boolean): Promise { + public async installPackages(packages: nbExtensionApis.IPackageDetails[], useMinVersion: boolean, databaseName: string): Promise { if (packages) { - await Promise.all(packages.map(x => this.installPackage(x, useMinVersion))); + await Promise.all(packages.map(x => this.installPackage(x, useMinVersion, databaseName))); } //TODO: use useMinVersion console.log(useMinVersion); } - private async installPackage(packageDetail: nbExtensionApis.IPackageDetails, useMinVersion: boolean): Promise { + private async installPackage(packageDetail: nbExtensionApis.IPackageDetails, useMinVersion: boolean, databaseName: string): Promise { if (useMinVersion) { let packageOverview = await this.getPackageOverview(packageDetail.name); if (packageOverview && packageOverview.versions) { @@ -60,16 +62,16 @@ export abstract class SqlPackageManageProviderBase { } } - await this.executeScripts(ScriptMode.Install, packageDetail); + await this.executeScripts(ScriptMode.Install, packageDetail, databaseName); } /** * Uninstalls given packages * @param packages Packages to uninstall */ - public async uninstallPackages(packages: nbExtensionApis.IPackageDetails[]): Promise { + public async uninstallPackages(packages: nbExtensionApis.IPackageDetails[], databaseName: string): Promise { if (packages) { - await Promise.all(packages.map(x => this.executeScripts(ScriptMode.Uninstall, x))); + await Promise.all(packages.map(x => this.executeScripts(ScriptMode.Uninstall, x, databaseName))); } } @@ -88,8 +90,8 @@ export abstract class SqlPackageManageProviderBase { /** * Returns list of packages */ - public async listPackages(): Promise { - let packages = await this.fetchPackages(); + public async listPackages(databaseName: string): Promise { + let packages = await this.fetchPackages(databaseName); if (packages) { packages = packages.sort((a, b) => this.comparePackages(a, b)); } else { @@ -110,6 +112,6 @@ export abstract class SqlPackageManageProviderBase { } protected abstract fetchPackage(packageName: string): Promise; - protected abstract fetchPackages(): Promise; - protected abstract executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails): Promise; + protected abstract fetchPackages(databaseName: string): Promise; + protected abstract executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise; } diff --git a/extensions/machine-learning-services/src/packageManagement/packageManagementService.ts b/extensions/machine-learning-services/src/packageManagement/packageManagementService.ts index b2cf31381d..fcb73a6555 100644 --- a/extensions/machine-learning-services/src/packageManagement/packageManagementService.ts +++ b/extensions/machine-learning-services/src/packageManagement/packageManagementService.ts @@ -103,15 +103,15 @@ export class PackageManagementService { * Returns python packages installed in SQL server instance * @param connection SQL Connection */ - public async getPythonPackages(connection: azdata.connection.ConnectionProfile): Promise { - return this._queryRunner.getPythonPackages(connection); + public async getPythonPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise { + return this._queryRunner.getPythonPackages(connection, databaseName); } /** * Returns python packages installed in SQL server instance * @param connection SQL Connection */ - public async getRPackages(connection: azdata.connection.ConnectionProfile): Promise { - return this._queryRunner.getRPackages(connection); + public async getRPackages(connection: azdata.connection.ConnectionProfile, databaseName: string): Promise { + return this._queryRunner.getRPackages(connection, databaseName); } } diff --git a/extensions/machine-learning-services/src/packageManagement/packageManager.ts b/extensions/machine-learning-services/src/packageManagement/packageManager.ts index 72e2bdb2f6..ec45279b29 100644 --- a/extensions/machine-learning-services/src/packageManagement/packageManager.ts +++ b/extensions/machine-learning-services/src/packageManagement/packageManager.ts @@ -93,7 +93,6 @@ export class PackageManager { // Execute the command // this._apiWrapper.executeCommand(constants.managePackagesCommand, { - multiLocations: false, defaultLocation: defaultProvider.packageTarget.location, defaultProviderId: defaultProvider.providerId }); @@ -116,7 +115,7 @@ export class PackageManager { * Installs dependencies for the extension */ public async installDependencies(): Promise { - await utils.executeTasks(this._apiWrapper, constants.installDependenciesMsgTaskName, [ + await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [ this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages), this.installRequiredRPackages()], true); } @@ -130,7 +129,7 @@ export class PackageManager { } await utils.createFolder(utils.getRPackagesFolderPath(this._rootFolder)); - await Promise.all(this._config.requiredSqlPythonPackages.map(x => this.installRPackage(x))); + await Promise.all(this._config.requiredSqlRPackages.map(x => this.installRPackage(x))); } /** @@ -151,7 +150,8 @@ export class PackageManager { let fileContent = ''; requiredPackages.forEach(packageDetails => { let hasVersion = ('version' in packageDetails) && !isNullOrUndefined(packageDetails['version']) && packageDetails['version'].length > 0; - if (!installedPackages.find(x => x.name === packageDetails['name'] && (!hasVersion || packageDetails['version'] === x.version))) { + if (!installedPackages.find(x => x.name === packageDetails['name'] + && (!hasVersion || utils.comparePackageVersions(packageDetails['version'] || '', x.version) <= 0))) { let packageNameDetail = hasVersion ? `${packageDetails.name}==${packageDetails.version}` : `${packageDetails.name}`; fileContent = `${fileContent}${packageNameDetail}\n`; } @@ -177,7 +177,7 @@ export class PackageManager { private async getInstalledPipPackages(): Promise { try { let cmd = `"${this.pythonExecutable}" -m pip list --format=json`; - let packagesInfo = await this._processService.executeBufferedCommand(cmd, this._outputChannel); + let packagesInfo = await this._processService.executeBufferedCommand(cmd, undefined); let packagesResult: nbExtensionApis.IPackageDetails[] = []; if (packagesInfo) { packagesResult = JSON.parse(packagesInfo); diff --git a/extensions/machine-learning-services/src/packageManagement/sqlPythonPackageManageProvider.ts b/extensions/machine-learning-services/src/packageManagement/sqlPythonPackageManageProvider.ts index 8ebf25e472..8639d701ab 100644 --- a/extensions/machine-learning-services/src/packageManagement/sqlPythonPackageManageProvider.ts +++ b/extensions/machine-learning-services/src/packageManagement/sqlPythonPackageManageProvider.ts @@ -9,7 +9,7 @@ import * as nbExtensionApis from '../typings/notebookServices'; import { ApiWrapper } from '../common/apiWrapper'; import { ProcessService } from '../common/processService'; import { Config } from '../configurations/config'; -import { SqlPackageManageProviderBase, ScriptMode } from './SqlPackageManageProviderBase'; +import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProviderBase'; import { HttpClient } from '../common/httpClient'; import * as utils from '../common/utils'; import { PackageManagementService } from './packageManagementService'; @@ -50,8 +50,8 @@ export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase /** * Returns list of packages */ - protected async fetchPackages(): Promise { - return await this._service.getPythonPackages(await this.getCurrentConnection()); + protected async fetchPackages(databaseName: string): Promise { + return await this._service.getPythonPackages(await this.getCurrentConnection(), databaseName); } /** @@ -59,14 +59,14 @@ export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase * @param packageDetails Packages to install or uninstall * @param scriptMode can be 'install' or 'uninstall' */ - protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails): Promise { + protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise { let connection = await this.getCurrentConnection(); let credentials = await this._apiWrapper.getCredentials(connection.connectionId); if (connection) { let port = '1433'; let server = connection.serverName; - let database = connection.databaseName ? `, database="${connection.databaseName}"` : ''; + let database = databaseName ? `, database="${databaseName}"` : ''; let index = connection.serverName.indexOf(','); if (index > 0) { port = connection.serverName.substring(index + 1); diff --git a/extensions/machine-learning-services/src/packageManagement/sqlRPackageManageProvider.ts b/extensions/machine-learning-services/src/packageManagement/sqlRPackageManageProvider.ts index a947ce1e66..c40560e4b8 100644 --- a/extensions/machine-learning-services/src/packageManagement/sqlRPackageManageProvider.ts +++ b/extensions/machine-learning-services/src/packageManagement/sqlRPackageManageProvider.ts @@ -10,7 +10,7 @@ import * as nbExtensionApis from '../typings/notebookServices'; import { ApiWrapper } from '../common/apiWrapper'; import { ProcessService } from '../common/processService'; import { Config } from '../configurations/config'; -import { SqlPackageManageProviderBase, ScriptMode } from './SqlPackageManageProviderBase'; +import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProviderBase'; import { HttpClient } from '../common/httpClient'; import * as constants from '../common/constants'; import { PackageManagementService } from './packageManagementService'; @@ -54,8 +54,8 @@ export class SqlRPackageManageProvider extends SqlPackageManageProviderBase impl /** * Returns list of packages */ - protected async fetchPackages(): Promise { - return await this._service.getRPackages(await this.getCurrentConnection()); + protected async fetchPackages(databaseName: string): Promise { + return await this._service.getRPackages(await this.getCurrentConnection(), databaseName); } /** @@ -63,12 +63,12 @@ export class SqlRPackageManageProvider extends SqlPackageManageProviderBase impl * @param packageDetails Packages to install or uninstall * @param scriptMode can be 'install' or 'uninstall' */ - protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails): Promise { + protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise { let connection = await this.getCurrentConnection(); let credentials = await this._apiWrapper.getCredentials(connection.connectionId); if (connection) { - let database = connection.databaseName ? `, database="${connection.databaseName}"` : ''; + let database = databaseName ? `, database="${databaseName}"` : ''; let connectionParts = `server="${connection.serverName}", uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"${database}`; let rCommandScript = scriptMode === ScriptMode.Install ? 'sql_install.packages' : 'sql_remove.packages'; diff --git a/extensions/machine-learning-services/src/test/packageManagement/packageManager.test.ts b/extensions/machine-learning-services/src/test/packageManagement/packageManager.test.ts index 950a2f60be..ca85657c92 100644 --- a/extensions/machine-learning-services/src/test/packageManagement/packageManager.test.ts +++ b/extensions/machine-learning-services/src/test/packageManagement/packageManager.test.ts @@ -116,7 +116,7 @@ describe('Package Manager', () => { it('installDependencies Should install packages that are not already installed', async function (): Promise { let testContext = createContext(); - //let packagesInstalled = false; + let packagesInstalled = false; let installedPackages = `[ {"name":"pymssql","version":"2.1.4"} ]`; @@ -128,15 +128,67 @@ describe('Package Manager', () => { }); testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => { if (command.indexOf('pip install') > 0) { - //packagesInstalled = true; + packagesInstalled = true; } return Promise.resolve(installedPackages); }); let packageManager = createPackageManager(testContext); await packageManager.installDependencies(); - //should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded); - //should.equal(packagesInstalled, true); + should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded); + should.equal(packagesInstalled, true); + }); + + it('installDependencies Should not install packages if runtime is disabled in setting', async function (): Promise { + let testContext = createContext(); + testContext.config.setup(x => x.rEnabled).returns(() => false); + testContext.config.setup(x => x.pythonEnabled).returns(() => false); + let packagesInstalled = false; + let installedPackages = `[ + {"name":"pymssql","version":"2.1.4"} + ]`; + testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ + label: 'Yes' + })); + testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { + operationInfo.operation(testContext.op); + }); + testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => { + if (command.indexOf('pip install') > 0 || command.indexOf('install.packages') > 0) { + packagesInstalled = true; + } + return Promise.resolve(installedPackages); + }); + + let packageManager = createPackageManager(testContext); + await packageManager.installDependencies(); + should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded); + should.equal(packagesInstalled, false); + }); + + it('installDependencies Should install packages that have older version installed', async function (): Promise { + let testContext = createContext(); + let packagesInstalled = false; + let installedPackages = `[ + {"name":"sqlmlutils","version":"0.1.1"} + ]`; + testContext.apiWrapper.setup(x => x.showQuickPick(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve({ + label: 'Yes' + })); + testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { + operationInfo.operation(testContext.op); + }); + testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => { + if (command.indexOf('pip install') > 0) { + packagesInstalled = true; + } + return Promise.resolve(installedPackages); + }); + + let packageManager = createPackageManager(testContext); + await packageManager.installDependencies(); + should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded); + should.equal(packagesInstalled, true); }); it('installDependencies Should install packages if list packages fails', async function (): Promise { @@ -197,7 +249,7 @@ describe('Package Manager', () => { { name: 'pymssql', version: '2.1.4' }, { name: 'sqlmlutils', version: '' } ]); - testContext.config.setup(x => x.requiredSqlPythonPackages).returns( () => [ + testContext.config.setup(x => x.requiredSqlRPackages).returns( () => [ { name: 'RODBCext', repository: 'https://cran.microsoft.com' }, { name: 'sqlmlutils', fileName: 'sqlmlutils_0.7.1.zip', downloadUrl: 'https://github.com/microsoft/sqlmlutils/blob/master/R/dist/sqlmlutils_0.7.1.zip?raw=true'} ]); diff --git a/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts b/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts index 0ee3592e29..ceb2f50202 100644 --- a/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts +++ b/extensions/machine-learning-services/src/test/packageManagement/sqlPythonPackageManageProvider.test.ts @@ -7,7 +7,6 @@ import * as azdata from 'azdata'; import * as should from 'should'; import 'mocha'; import * as TypeMoq from 'typemoq'; -import * as constants from '../../common/constants'; import { SqlPythonPackageManageProvider } from '../../packageManagement/sqlPythonPackageManageProvider'; import { createContext, TestContext } from './utils'; import * as nbExtensionApis from '../../typings/notebookServices'; @@ -40,10 +39,10 @@ describe('SQL Python Package Manager', () => { let connection = new azdata.connection.ConnectionProfile(); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); + testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); let provider = createProvider(testContext); - let actual = await provider.listPackages(); + let actual = await provider.listPackages(connection.databaseName); let expected = [ { 'name': 'a-name', @@ -72,10 +71,10 @@ describe('SQL Python Package Manager', () => { let connection = new azdata.connection.ConnectionProfile(); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); + testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); let provider = createProvider(testContext); - let actual = await provider.listPackages(); + let actual = await provider.listPackages(connection.databaseName); let expected = [ { 'name': 'b-name', @@ -95,10 +94,10 @@ describe('SQL Python Package Manager', () => { let connection = new azdata.connection.ConnectionProfile(); let packages: nbExtensionApis.IPackageDetails[]; testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); + testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); let provider = createProvider(testContext); - let actual = await provider.listPackages(); + let actual = await provider.listPackages(connection.databaseName); let expected: nbExtensionApis.IPackageDetails[] = []; should.deepEqual(actual, expected); }); @@ -108,10 +107,10 @@ describe('SQL Python Package Manager', () => { let connection = new azdata.connection.ConnectionProfile(); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve([])); + testContext.serverConfigManager.setup(x => x.getPythonPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve([])); let provider = createProvider(testContext); - let actual = await provider.listPackages(); + let actual = await provider.listPackages(connection.databaseName); let expected: nbExtensionApis.IPackageDetails[] = []; should.deepEqual(actual, expected); }); @@ -152,7 +151,7 @@ describe('SQL Python Package Manager', () => { }); let provider = createProvider(testContext); - await provider.installPackages(packages, false); + await provider.installPackages(packages, false, connection.databaseName); should.deepEqual(packagesUpdated, true); }); @@ -192,7 +191,7 @@ describe('SQL Python Package Manager', () => { }); let provider = createProvider(testContext); - await provider.uninstallPackages(packages); + await provider.uninstallPackages(packages, connection.databaseName); should.deepEqual(packagesUpdated, true); }); @@ -233,7 +232,7 @@ describe('SQL Python Package Manager', () => { }); let provider = createProvider(testContext); - await provider.installPackages(packages, false); + await provider.installPackages(packages, false, connection.databaseName); should.deepEqual(packagesUpdated, true); }); @@ -255,7 +254,7 @@ describe('SQL Python Package Manager', () => { let provider = createProvider(testContext); - await provider.installPackages(packages, false); + await provider.installPackages(packages, false, connection.databaseName); should.deepEqual(packagesUpdated, false); }); @@ -277,7 +276,7 @@ describe('SQL Python Package Manager', () => { let provider = createProvider(testContext); - await provider.uninstallPackages(packages); + await provider.uninstallPackages(packages, connection.databaseName); should.deepEqual(packagesUpdated, false); }); @@ -346,42 +345,44 @@ describe('SQL Python Package Manager', () => { should.deepEqual(actual, packagePreview); }); - it('getLocationTitle Should default string for no connection', async function (): Promise { + it('getLocations Should return empty array for no connection', async function (): Promise { let testContext = createContext(); let connection: azdata.connection.ConnectionProfile; testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); let provider = createProvider(testContext); - let actual = await provider.getLocationTitle(); + let actual = await provider.getLocations(); - should.deepEqual(actual, constants.noConnectionError); + should.deepEqual(actual, []); }); - it('getLocationTitle Should return connection title string for valid connection', async function (): Promise { + it('getLocations Should return database names for valid connection', async function (): Promise { let testContext = createContext(); let connection = new azdata.connection.ConnectionProfile(); connection.serverName = 'serverName'; connection.databaseName = 'databaseName'; + const databaseNames = [ + 'db1', + 'db2' + ]; + const expected = [ + { + displayName: 'db1', + name: 'db1' + }, + { + displayName: 'db2', + name: 'db2' + } + ]; testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + testContext.apiWrapper.setup(x => x.listDatabases(connection.connectionId)).returns(() => { return Promise.resolve(databaseNames); }); let provider = createProvider(testContext); - let actual = await provider.getLocationTitle(); + let actual = await provider.getLocations(); - should.deepEqual(actual, `${connection.serverName} ${connection.databaseName}`); - }); - - it('getLocationTitle Should return server name as connection title if there is not database name', async function (): Promise { - let testContext = createContext(); - - let connection = new azdata.connection.ConnectionProfile(); - connection.serverName = 'serverName'; - testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - - let provider = createProvider(testContext); - let actual = await provider.getLocationTitle(); - - should.deepEqual(actual, `${connection.serverName} `); + should.deepEqual(actual, expected); }); function createProvider(testContext: TestContext): SqlPythonPackageManageProvider { diff --git a/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts b/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts index 6752a00dc1..4bf4146785 100644 --- a/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts +++ b/extensions/machine-learning-services/src/test/packageManagement/sqlRPackageManageProvider.test.ts @@ -7,7 +7,6 @@ import * as azdata from 'azdata'; import * as should from 'should'; import 'mocha'; import * as TypeMoq from 'typemoq'; -import * as constants from '../../common/constants'; import { SqlRPackageManageProvider } from '../../packageManagement/sqlRPackageManageProvider'; import { createContext, TestContext } from './utils'; import * as nbExtensionApis from '../../typings/notebookServices'; @@ -40,10 +39,10 @@ describe('SQL R Package Manager', () => { let connection = new azdata.connection.ConnectionProfile(); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); + testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); let provider = createProvider(testContext); - let actual = await provider.listPackages(); + let actual = await provider.listPackages(connection.databaseName); let expected = [ { 'name': 'a-name', @@ -63,10 +62,10 @@ describe('SQL R Package Manager', () => { let connection = new azdata.connection.ConnectionProfile(); let packages: nbExtensionApis.IPackageDetails[]; testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); + testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(packages)); let provider = createProvider(testContext); - let actual = await provider.listPackages(); + let actual = await provider.listPackages(connection.databaseName); let expected: nbExtensionApis.IPackageDetails[] = []; should.deepEqual(actual, expected); }); @@ -76,10 +75,10 @@ describe('SQL R Package Manager', () => { let connection = new azdata.connection.ConnectionProfile(); testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve([])); + testContext.serverConfigManager.setup(x => x.getRPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve([])); let provider = createProvider(testContext); - let actual = await provider.listPackages(); + let actual = await provider.listPackages(connection.databaseName); let expected: nbExtensionApis.IPackageDetails[] = []; should.deepEqual(actual, expected); }); @@ -118,7 +117,7 @@ describe('SQL R Package Manager', () => { }); let provider = createProvider(testContext); - await provider.installPackages(packages, false); + await provider.installPackages(packages, false, connection.databaseName); should.deepEqual(packagesUpdated, true); }); @@ -157,7 +156,7 @@ describe('SQL R Package Manager', () => { }); let provider = createProvider(testContext); - await provider.uninstallPackages(packages); + await provider.uninstallPackages(packages, connection.databaseName); should.deepEqual(packagesUpdated, true); }); @@ -179,7 +178,7 @@ describe('SQL R Package Manager', () => { let provider = createProvider(testContext); - await provider.installPackages(packages, false); + await provider.installPackages(packages, false, connection.databaseName); should.deepEqual(packagesUpdated, false); }); @@ -201,7 +200,7 @@ describe('SQL R Package Manager', () => { let provider = createProvider(testContext); - await provider.uninstallPackages(packages); + await provider.uninstallPackages(packages, connection.databaseName); should.deepEqual(packagesUpdated, false); }); @@ -271,42 +270,44 @@ describe('SQL R Package Manager', () => { should.deepEqual(actual, packagePreview); }); - it('getLocationTitle Should default string for no connection', async function (): Promise { + it('getLocations Should return empty array for no connection', async function (): Promise { let testContext = createContext(); let connection: azdata.connection.ConnectionProfile; testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); let provider = createProvider(testContext); - let actual = await provider.getLocationTitle(); + let actual = await provider.getLocations(); - should.deepEqual(actual, constants.noConnectionError); + should.deepEqual(actual, []); }); - it('getLocationTitle Should return connection title string for valid connection', async function (): Promise { + it('getLocations Should return database names for valid connection', async function (): Promise { let testContext = createContext(); let connection = new azdata.connection.ConnectionProfile(); connection.serverName = 'serverName'; connection.databaseName = 'databaseName'; + const databaseNames = [ + 'db1', + 'db2' + ]; + const expected = [ + { + displayName: 'db1', + name: 'db1' + }, + { + displayName: 'db2', + name: 'db2' + } + ]; testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + testContext.apiWrapper.setup(x => x.listDatabases(connection.connectionId)).returns(() => { return Promise.resolve(databaseNames); }); let provider = createProvider(testContext); - let actual = await provider.getLocationTitle(); + let actual = await provider.getLocations(); - should.deepEqual(actual, `${connection.serverName} ${connection.databaseName}`); - }); - - it('getLocationTitle Should return server name as connection title if there is not database name', async function (): Promise { - let testContext = createContext(); - - let connection = new azdata.connection.ConnectionProfile(); - connection.serverName = 'serverName'; - testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); - - let provider = createProvider(testContext); - let actual = await provider.getLocationTitle(); - - should.deepEqual(actual, `${connection.serverName} `); + should.deepEqual(actual, expected); }); function createProvider(testContext: TestContext): SqlRPackageManageProvider { diff --git a/extensions/machine-learning-services/src/test/queryRunner.test.ts b/extensions/machine-learning-services/src/test/queryRunner.test.ts index 6183ae593e..7786da774f 100644 --- a/extensions/machine-learning-services/src/test/queryRunner.test.ts +++ b/extensions/machine-learning-services/src/test/queryRunner.test.ts @@ -59,7 +59,7 @@ describe('Query Runner', () => { let queryProvider: azdata.QueryProvider; testContext.apiWrapper.setup(x => x.getProvider(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => queryProvider); - let actual = await queryRunner.getPythonPackages(connection); + let actual = await queryRunner.getPythonPackages(connection, connection.databaseName); should.deepEqual(actual, []); }); @@ -70,7 +70,7 @@ describe('Query Runner', () => { testContext.queryProvider.runQueryAndReturn = () => { return Promise.reject(); }; testContext.apiWrapper.setup(x => x.getProvider(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider); - let actual = await queryRunner.getPythonPackages(connection); + let actual = await queryRunner.getPythonPackages(connection, connection.databaseName); should.deepEqual(actual, []); }); @@ -117,7 +117,7 @@ describe('Query Runner', () => { testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); }; testContext.apiWrapper.setup(x => x.getProvider(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider); - let actual = await queryRunner.getPythonPackages(connection); + let actual = await queryRunner.getPythonPackages(connection, connection.databaseName); should.deepEqual(actual, expected); }); @@ -138,7 +138,7 @@ describe('Query Runner', () => { testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); }; testContext.apiWrapper.setup(x => x.getProvider(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider); - let actual = await queryRunner.getPythonPackages(connection); + let actual = await queryRunner.getPythonPackages(connection, connection.databaseName); should.deepEqual(actual, expected); }); diff --git a/extensions/machine-learning-services/src/typings/notebookServices.d.ts b/extensions/machine-learning-services/src/typings/notebookServices.d.ts index 74ad82a49d..eb643b105e 100644 --- a/extensions/machine-learning-services/src/typings/notebookServices.d.ts +++ b/extensions/machine-learning-services/src/typings/notebookServices.d.ts @@ -51,13 +51,56 @@ export interface IPackageOverview { summary: string; } -export interface IPackageManageProvider { - providerId: string; - packageTarget: IPackageTarget; - listPackages(): Promise - installPackages(package: IPackageDetails[], useMinVersion: boolean): Promise; - uninstallPackages(package: IPackageDetails[]): Promise; - canUseProvider(): Promise; - getLocationTitle(): Promise; - getPackageOverview(packageName: string): Promise +export interface IPackageLocation { + name: string; + displayName: string; +} + +/** + * Package manage provider interface + */ +export interface IPackageManageProvider { + /** + * Provider id + */ + providerId: string; + + /** + * package target + */ + packageTarget: IPackageTarget; + + /** + * Returns list of installed packages + */ + listPackages(location?: string): Promise; + + /** + * Installs give packages + * @param package Packages to install + * @param useMinVersion if true, minimal version will be used + */ + installPackages(package: IPackageDetails[], useMinVersion: boolean, location?: string): Promise; + + /** + * Uninstalls given packages + * @param package package to uninstall + */ + uninstallPackages(package: IPackageDetails[], location?: string): Promise; + + /** + * Returns true if the provider can be used in current context + */ + canUseProvider(): Promise; + + /** + * Returns location title + */ + getLocations(): Promise; + + /** + * Returns Package Overview + * @param packageName package name + */ + getPackageOverview(packageName: string): Promise; } diff --git a/extensions/notebook/src/dialog/managePackages/addNewPackageTab.ts b/extensions/notebook/src/dialog/managePackages/addNewPackageTab.ts index d0223f4f6f..c111e9b27e 100644 --- a/extensions/notebook/src/dialog/managePackages/addNewPackageTab.ts +++ b/extensions/notebook/src/dialog/managePackages/addNewPackageTab.ts @@ -147,7 +147,7 @@ export class AddNewPackageTab { let pipPackage: PipPackageOverview; pipPackage = await this.dialog.model.getPackageOverview(packageName); - if (!pipPackage.versions || pipPackage.versions.length === 0) { + if (!pipPackage?.versions || pipPackage.versions.length === 0) { this.dialog.showErrorMessage( localize('managePackages.noVersionsFound', "Could not find any valid versions for the specified package")); diff --git a/extensions/notebook/src/dialog/managePackages/installedPackagesTab.ts b/extensions/notebook/src/dialog/managePackages/installedPackagesTab.ts index 60302554d8..c9962e625f 100644 --- a/extensions/notebook/src/dialog/managePackages/installedPackagesTab.ts +++ b/extensions/notebook/src/dialog/managePackages/installedPackagesTab.ts @@ -20,11 +20,13 @@ export class InstalledPackagesTab { private installedPkgTab: azdata.window.DialogTab; private packageTypeDropdown: azdata.DropDownComponent; - private locationComponent: azdata.TextComponent; + private locationComponent: azdata.Component; private installedPackageCount: azdata.TextComponent; private installedPackagesTable: azdata.TableComponent; private installedPackagesLoader: azdata.LoadingComponent; private uninstallPackageButton: azdata.ButtonComponent; + private view: azdata.ModelView | undefined; + private formBuilder: azdata.FormBuilder; constructor(private dialog: ManagePackagesDialog, private jupyterInstallation: JupyterServerInstallation) { this.prompter = new CodeAdapter(); @@ -32,14 +34,7 @@ export class InstalledPackagesTab { this.installedPkgTab = azdata.window.createTab(localize('managePackages.installedTabTitle', "Installed")); this.installedPkgTab.registerContent(async view => { - - // TODO: only supporting single location for now. We should add a drop down for multi locations mode - // - let locationTitle = await this.dialog.model.getLocationTitle(); - this.locationComponent = view.modelBuilder.text().withProperties({ - value: locationTitle - }).component(); - + this.view = view; let dropdownValues = this.dialog.model.getPackageTypes().map(x => { return { name: x.providerId, @@ -52,11 +47,17 @@ export class InstalledPackagesTab { value: defaultPackageType }).component(); this.dialog.changeProvider(defaultPackageType.providerId); - this.packageTypeDropdown.onValueChanged(() => { - this.dialog.resetPages((this.packageTypeDropdown.value).name) - .catch(err => { - this.dialog.showErrorMessage(utils.getErrorMessage(err)); - }); + this.packageTypeDropdown.onValueChanged(async () => { + this.dialog.changeProvider((this.packageTypeDropdown.value).name); + try { + await this.resetLocations(); + await this.dialog.resetPages(); + } + catch (err) { + this.dialog.showErrorMessage(utils.getErrorMessage(err)); + + } + }); this.installedPackageCount = view.modelBuilder.text().withProperties({ @@ -81,11 +82,8 @@ export class InstalledPackagesTab { }).component(); this.uninstallPackageButton.onDidClick(() => this.doUninstallPackage()); - let formModel = view.modelBuilder.formContainer() + this.formBuilder = view.modelBuilder.formContainer() .withFormItems([{ - component: this.locationComponent, - title: localize('managePackages.location', "Location") - }, { component: this.packageTypeDropdown, title: localize('managePackages.packageType', "Package Type") }, { @@ -97,10 +95,11 @@ export class InstalledPackagesTab { }, { component: this.uninstallPackageButton, title: '' - }]).component(); + }]); + await this.resetLocations(); this.installedPackagesLoader = view.modelBuilder.loadingComponent() - .withItem(formModel) + .withItem(this.formBuilder.component()) .withProperties({ loading: true }).component(); @@ -112,6 +111,68 @@ export class InstalledPackagesTab { }); } + private async resetLocations(): Promise { + if (this.view) { + if (this.locationComponent) { + this.formBuilder.removeFormItem({ + component: this.locationComponent, + title: localize('managePackages.location', "Location") + }); + } + + this.locationComponent = await InstalledPackagesTab.getLocationComponent(this.view, this.dialog); + + this.formBuilder.insertFormItem({ + component: this.locationComponent, + title: localize('managePackages.location', "Location") + }, 1); + } + } + + /** + * Creates a component for package locations + * @param view Model view + * @param dialog Manage package dialog + */ + public static async getLocationComponent(view: azdata.ModelView, dialog: ManagePackagesDialog): Promise { + const locations = await dialog.model.getLocations(); + let component: azdata.Component; + if (locations && locations.length === 1) { + component = view.modelBuilder.text().withProperties({ + value: locations[0].displayName + }).component(); + } else if (locations) { + let dropdownValues = locations.map(x => { + return { + name: x.name, + displayName: x.displayName + }; + }); + let locationDropDown = view.modelBuilder.dropDown().withProperties({ + values: dropdownValues, + value: dropdownValues[0] + }).component(); + + locationDropDown.onValueChanged(async () => { + dialog.changeLocation((locationDropDown.value).name); + try { + await dialog.resetPages(); + } + catch (err) { + dialog.showErrorMessage(utils.getErrorMessage(err)); + } + }); + component = locationDropDown; + } else { + component = view.modelBuilder.text().withProperties({ + }).component(); + } + if (locations && locations.length > 0) { + dialog.changeLocation(locations[0].name); + } + return component; + } + public get tab(): azdata.window.DialogTab { return this.installedPkgTab; } diff --git a/extensions/notebook/src/dialog/managePackages/managePackagesDialog.ts b/extensions/notebook/src/dialog/managePackages/managePackagesDialog.ts index 644888c61d..e3bcdcc69a 100644 --- a/extensions/notebook/src/dialog/managePackages/managePackagesDialog.ts +++ b/extensions/notebook/src/dialog/managePackages/managePackagesDialog.ts @@ -67,14 +67,17 @@ export class ManagePackagesDialog { } /** - * Resets the tabs for given provider Id - * @param providerId Package Management Provider Id + * Changes the current location + * @param location location name */ - public async resetPages(providerId: string): Promise { + public changeLocation(location: string): void { + this.model.changeLocation(location); + } - // Change the provider in the model - // - this.changeProvider(providerId); + /** + * Resets the tabs for given provider Id + */ + public async resetPages(): Promise { // Load packages for given provider // diff --git a/extensions/notebook/src/dialog/managePackages/managePackagesDialogModel.ts b/extensions/notebook/src/dialog/managePackages/managePackagesDialogModel.ts index 66bee24037..4300f457d1 100644 --- a/extensions/notebook/src/dialog/managePackages/managePackagesDialogModel.ts +++ b/extensions/notebook/src/dialog/managePackages/managePackagesDialogModel.ts @@ -4,10 +4,9 @@ *--------------------------------------------------------------------------------------------*/ import { JupyterServerInstallation } from '../../jupyter/jupyterServerInstallation'; -import { IPackageManageProvider, IPackageDetails, IPackageOverview } from '../../types'; +import { IPackageManageProvider, IPackageDetails, IPackageOverview, IPackageLocation } from '../../types'; export interface ManagePackageDialogOptions { - multiLocations: boolean; defaultLocation?: string; defaultProviderId?: string; } @@ -23,11 +22,12 @@ export interface ProviderPackageType { export class ManagePackagesDialogModel { private _currentProvider: string; + private _currentLocation: string; /** * A set for locations */ - private _locations: Set = new Set(); + private _locationTypes: Set = new Set(); /** * Map of locations to providers @@ -77,15 +77,10 @@ export class ManagePackagesDialogModel { if (this._options.defaultProviderId && !this._packageManageProviders.has(this._options.defaultProviderId)) { throw new Error(`Invalid default provider id '${this._options.defaultProviderId}`); } - - if (!this._options.multiLocations && !this.defaultLocation) { - throw new Error('Default location not specified for single location mode'); - } } private get defaultOptions(): ManagePackageDialogOptions { return { - multiLocations: true, defaultLocation: undefined, defaultProviderId: undefined }; @@ -120,13 +115,6 @@ export class ManagePackagesDialogModel { return undefined; } - /** - * Returns true if multi locations mode is enabled - */ - public get multiLocationMode(): boolean { - return this.options.multiLocations; - } - /** * Returns options */ @@ -135,17 +123,17 @@ export class ManagePackagesDialogModel { } /** - * returns the array of target locations + * returns the array of target location types */ - public get targetLocations(): string[] { - return Array.from(this._locations.keys()); + public get targetLocationTypes(): string[] { + return Array.from(this._locationTypes.keys()); } /** * Returns the default location */ public get defaultLocation(): string { - return this.options.defaultLocation || this.targetLocations[0]; + return this.options.defaultLocation || this.targetLocationTypes[0]; } /** @@ -164,8 +152,8 @@ export class ManagePackagesDialogModel { for (let index = 0; index < keyArray.length; index++) { const element = this.packageManageProviders.get(keyArray[index]); if (await element.canUseProvider()) { - if (!this._locations.has(element.packageTarget.location)) { - this._locations.add(element.packageTarget.location); + if (!this._locationTypes.has(element.packageTarget.location)) { + this._locationTypes.add(element.packageTarget.location); } if (!this._packageTypes.has(element.packageTarget.location)) { this._packageTypes.set(element.packageTarget.location, []); @@ -205,7 +193,7 @@ export class ManagePackagesDialogModel { public async listPackages(): Promise { let provider = this.currentPackageManageProvider; if (provider) { - return await provider.listPackages(); + return await provider.listPackages(this._currentLocation); } else { throw new Error('Current Provider is not set'); } @@ -222,6 +210,13 @@ export class ManagePackagesDialogModel { } } + /** + * Changes the current location + */ + public changeLocation(location: string): void { + this._currentLocation = location; + } + /** * Installs given packages using current provider * @param packages Packages to install @@ -229,7 +224,7 @@ export class ManagePackagesDialogModel { public async installPackages(packages: IPackageDetails[]): Promise { let provider = this.currentPackageManageProvider; if (provider) { - await provider.installPackages(packages, false); + await provider.installPackages(packages, false, this._currentLocation); } else { throw new Error('Current Provider is not set'); } @@ -238,10 +233,10 @@ export class ManagePackagesDialogModel { /** * Returns the location title for current provider */ - public async getLocationTitle(): Promise { + public async getLocations(): Promise { let provider = this.currentPackageManageProvider; if (provider) { - return await provider.getLocationTitle(); + return await provider.getLocations(); } return Promise.resolve(undefined); } @@ -253,7 +248,7 @@ export class ManagePackagesDialogModel { public async uninstallPackages(packages: IPackageDetails[]): Promise { let provider = this.currentPackageManageProvider; if (provider) { - await provider.uninstallPackages(packages); + await provider.uninstallPackages(packages, this._currentLocation); } else { throw new Error('Current Provider is not set'); } diff --git a/extensions/notebook/src/jupyter/jupyterController.ts b/extensions/notebook/src/jupyter/jupyterController.ts index 3921185430..7873066c33 100644 --- a/extensions/notebook/src/jupyter/jupyterController.ts +++ b/extensions/notebook/src/jupyter/jupyterController.ts @@ -207,7 +207,6 @@ export class JupyterController implements vscode.Disposable { try { if (!options) { options = { - multiLocations: false, defaultLocation: constants.localhostName, defaultProviderId: LocalPipPackageManageProvider.ProviderId }; diff --git a/extensions/notebook/src/jupyter/localCondaPackageManageProvider.ts b/extensions/notebook/src/jupyter/localCondaPackageManageProvider.ts index c5f4b60f44..a0746218cd 100644 --- a/extensions/notebook/src/jupyter/localCondaPackageManageProvider.ts +++ b/extensions/notebook/src/jupyter/localCondaPackageManageProvider.ts @@ -3,7 +3,7 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import { IPackageManageProvider, IPackageDetails, IPackageTarget, IPackageOverview } from '../types'; +import { IPackageManageProvider, IPackageDetails, IPackageTarget, IPackageOverview, IPackageLocation } from '../types'; import { IJupyterServerInstallation } from './jupyterServerInstallation'; import * as constants from '../common/constants'; import * as utils from '../common/utils'; @@ -35,7 +35,7 @@ export class LocalCondaPackageManageProvider implements IPackageManageProvider { /** * Returns list of packages */ - public async listPackages(): Promise { + public async listPackages(location?: string): Promise { return await this.jupyterInstallation.getInstalledCondaPackages(); } @@ -44,7 +44,7 @@ export class LocalCondaPackageManageProvider implements IPackageManageProvider { * @param packages Packages to install * @param useMinVersion minimum version */ - installPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise { + installPackages(packages: IPackageDetails[], useMinVersion: boolean, location?: string): Promise { return this.jupyterInstallation.installCondaPackages(packages, useMinVersion); } @@ -52,7 +52,7 @@ export class LocalCondaPackageManageProvider implements IPackageManageProvider { * Uninstalls given packages * @param packages Packages to uninstall */ - uninstallPackages(packages: IPackageDetails[]): Promise { + uninstallPackages(packages: IPackageDetails[], location?: string): Promise { return this.jupyterInstallation.uninstallCondaPackages(packages); } @@ -66,8 +66,8 @@ export class LocalCondaPackageManageProvider implements IPackageManageProvider { /** * Returns location title */ - getLocationTitle(): Promise { - return Promise.resolve(constants.localhostTitle); + getLocations(): Promise { + return Promise.resolve([{ displayName: constants.localhostTitle, name: constants.localhostName }]); } /** diff --git a/extensions/notebook/src/jupyter/localPipPackageManageProvider.ts b/extensions/notebook/src/jupyter/localPipPackageManageProvider.ts index 1be49dcc78..4d3372e93f 100644 --- a/extensions/notebook/src/jupyter/localPipPackageManageProvider.ts +++ b/extensions/notebook/src/jupyter/localPipPackageManageProvider.ts @@ -3,7 +3,7 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import { IPackageManageProvider, IPackageDetails, IPackageTarget, IPackageOverview } from '../types'; +import { IPackageManageProvider, IPackageDetails, IPackageTarget, IPackageOverview, IPackageLocation } from '../types'; import { IJupyterServerInstallation } from './jupyterServerInstallation'; import * as constants from '../common/constants'; import * as utils from '../common/utils'; @@ -38,7 +38,7 @@ export class LocalPipPackageManageProvider implements IPackageManageProvider { /** * Returns list of packages */ - public async listPackages(): Promise { + public async listPackages(location?: string): Promise { return await this.jupyterInstallation.getInstalledPipPackages(); } @@ -47,7 +47,7 @@ export class LocalPipPackageManageProvider implements IPackageManageProvider { * @param packages Packages to install * @param useMinVersion minimum version */ - installPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise { + installPackages(packages: IPackageDetails[], useMinVersion: boolean, location?: string): Promise { return this.jupyterInstallation.installPipPackages(packages, useMinVersion); } @@ -55,7 +55,7 @@ export class LocalPipPackageManageProvider implements IPackageManageProvider { * Uninstalls given packages * @param packages Packages to uninstall */ - uninstallPackages(packages: IPackageDetails[]): Promise { + uninstallPackages(packages: IPackageDetails[], location?: string): Promise { return this.jupyterInstallation.uninstallPipPackages(packages); } @@ -69,8 +69,8 @@ export class LocalPipPackageManageProvider implements IPackageManageProvider { /** * Returns location title */ - getLocationTitle(): Promise { - return Promise.resolve(constants.localhostTitle); + getLocations(): Promise { + return Promise.resolve([{ displayName: constants.localhostTitle, name: constants.localhostName }]); } /** diff --git a/extensions/notebook/src/test/managePackages/managePackagesDialog.test.ts b/extensions/notebook/src/test/managePackages/managePackagesDialog.test.ts new file mode 100644 index 0000000000..46e5df1798 --- /dev/null +++ b/extensions/notebook/src/test/managePackages/managePackagesDialog.test.ts @@ -0,0 +1,295 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the Source EULA. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ +import * as azdata from 'azdata'; +import * as vscode from 'vscode'; +import * as TypeMoq from 'typemoq'; +import { ManagePackagesDialog } from '../../dialog/managePackages/managePackagesDialog'; +import { ManagePackagesDialogModel } from '../../dialog/managePackages/managePackagesDialogModel'; +import { IPackageManageProvider, IPackageLocation } from '../../types'; +import { LocalCondaPackageManageProvider } from '../../jupyter/localCondaPackageManageProvider'; +import { InstalledPackagesTab } from '../../dialog/managePackages/installedPackagesTab'; +import should = require('should'); + +interface TestContext { + view: azdata.ModelView; + onClick: vscode.EventEmitter; + dialog: TypeMoq.IMock; + model: TypeMoq.IMock; +} + +describe('Manage Package Dialog', () => { + + it('getLocationComponent should create text component for one location', async function (): Promise { + let testContext = createViewContext(); + let locations = [ + { + displayName: 'dl1', + name: 'nl1' + } + ]; + testContext.model.setup(x => x.getLocations()).returns(() => Promise.resolve(locations)); + testContext.model.setup(x => x.changeLocation('nl1')); + testContext.dialog.setup(x => x.changeLocation('nl1')); + + let actual = await InstalledPackagesTab.getLocationComponent(testContext.view, testContext.dialog.object); + should.equal('onTextChanged' in actual, true); + testContext.dialog.verify(x => x.changeLocation('nl1'), TypeMoq.Times.once()); + }); + + it('getLocationComponent should create text component for undefined location', async function (): Promise { + let testContext = createViewContext(); + let locations: IPackageLocation[] | undefined = undefined; + testContext.model.setup(x => x.getLocations()).returns(() => Promise.resolve(locations)); + + let actual = await InstalledPackagesTab.getLocationComponent(testContext.view, testContext.dialog.object); + should.equal('onTextChanged' in actual, true); + }); + + it('getLocationComponent should create drop down component for more than one location', async function (): Promise { + let testContext = createViewContext(); + let locations = [ + { + displayName: 'dl1', + name: 'nl1' + }, + { + displayName: 'dl2', + name: 'nl2' + } + ]; + testContext.model.setup(x => x.getLocations()).returns(() => Promise.resolve(locations)); + testContext.dialog.setup(x => x.changeLocation('nl1')); + testContext.dialog.setup(x => x.resetPages()).returns(() => Promise.resolve()); + + let actual = await InstalledPackagesTab.getLocationComponent(testContext.view, testContext.dialog.object); + should.equal('onValueChanged' in actual, true); + testContext.dialog.verify(x => x.changeLocation('nl1'), TypeMoq.Times.once()); + (actual).value = { + displayName: 'dl2', + name: 'nl2' + }; + testContext.onClick.fire(); + testContext.dialog.verify(x => x.changeLocation('nl2'), TypeMoq.Times.once()); + testContext.dialog.verify(x => x.resetPages(), TypeMoq.Times.once()); + + }); + + it('getLocationComponent should show error if reset pages fails', async function (): Promise { + let testContext = createViewContext(); + let locations = [ + { + displayName: 'dl1', + name: 'nl1' + }, + { + displayName: 'dl2', + name: 'nl2' + } + ]; + testContext.model.setup(x => x.getLocations()).returns(() => Promise.resolve(locations)); + testContext.dialog.setup(x => x.changeLocation('nl1')); + testContext.dialog.setup(x => x.resetPages()).throws(new Error('failed')); + testContext.dialog.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve()); + + let actual = await InstalledPackagesTab.getLocationComponent(testContext.view, testContext.dialog.object); + should.equal('onValueChanged' in actual, true); + testContext.dialog.verify(x => x.changeLocation('nl1'), TypeMoq.Times.once()); + (actual).value = { + displayName: 'dl2', + name: 'nl2' + }; + testContext.onClick.fire(); + testContext.dialog.verify(x => x.changeLocation('nl2'), TypeMoq.Times.once()); + testContext.dialog.verify(x => x.showErrorMessage(TypeMoq.It.isAny()), TypeMoq.Times.once()); + + }); + + function createViewContext(): TestContext { + let packageManageProviders = new Map(); + packageManageProviders.set(LocalCondaPackageManageProvider.ProviderId, new LocalCondaPackageManageProvider(undefined)); + let model = TypeMoq.Mock.ofInstance(new ManagePackagesDialogModel(undefined, packageManageProviders)); + let dialog = TypeMoq.Mock.ofInstance(new ManagePackagesDialog(model.object)); + dialog.setup(x => x.model).returns(() => model.object); + + let onClick: vscode.EventEmitter = new vscode.EventEmitter(); + + let componentBase: azdata.Component = { + id: '', + updateProperties: () => Promise.resolve(), + updateProperty: () => Promise.resolve(), + updateCssStyles: undefined!, + onValidityChanged: undefined!, + valid: true, + validate: undefined!, + focus: undefined! + }; + let button: azdata.ButtonComponent = Object.assign({}, componentBase, { + onDidClick: onClick.event + }); + let radioButton: azdata.RadioButtonComponent = Object.assign({}, componentBase, { + onDidClick: onClick.event + }); + const components: azdata.Component[] = []; + let container = { + clearItems: () => { }, + addItems: () => { }, + addItem: () => { }, + removeItem: () => true, + insertItem: () => { }, + items: components, + setLayout: () => { } + }; + let form: azdata.FormContainer = Object.assign({}, componentBase, container, { + }); + let flex: azdata.FlexContainer = Object.assign({}, componentBase, container, { + }); + + let buttonBuilder: azdata.ComponentBuilder = { + component: () => button, + withProperties: () => buttonBuilder, + withValidation: () => buttonBuilder + }; + let radioButtonBuilder: azdata.ComponentBuilder = { + component: () => radioButton, + withProperties: () => radioButtonBuilder, + withValidation: () => radioButtonBuilder + }; + let inputBox: () => azdata.InputBoxComponent = () => Object.assign({}, componentBase, { + onTextChanged: undefined!, + onEnterKeyPressed: undefined!, + value: '' + }); + let image: () => azdata.ImageComponent = () => Object.assign({}, componentBase, { + + }); + let dropdown: () => azdata.DropDownComponent = () => Object.assign({}, componentBase, { + onValueChanged: onClick.event, + value: { + name: '', + displayName: '' + }, + values: [] + }); + let declarativeTable: () => azdata.DeclarativeTableComponent = () => Object.assign({}, componentBase, { + onDataChanged: undefined!, + data: [], + columns: [] + }); + + let loadingComponent: () => azdata.LoadingComponent = () => Object.assign({}, componentBase, { + loading: false, + component: undefined! + }); + + let declarativeTableBuilder: azdata.ComponentBuilder = { + component: () => declarativeTable(), + withProperties: () => declarativeTableBuilder, + withValidation: () => declarativeTableBuilder + }; + + let loadingBuilder: azdata.LoadingComponentBuilder = { + component: () => loadingComponent(), + withProperties: () => loadingBuilder, + withValidation: () => loadingBuilder, + withItem: () => loadingBuilder + }; + + let formBuilder: azdata.FormBuilder = Object.assign({}, { + component: () => form, + addFormItem: () => { }, + insertFormItem: () => { }, + removeFormItem: () => true, + addFormItems: () => { }, + withFormItems: () => formBuilder, + withProperties: () => formBuilder, + withValidation: () => formBuilder, + withItems: () => formBuilder, + withLayout: () => formBuilder + }); + + let flexBuilder: azdata.FlexBuilder = Object.assign({}, { + component: () => flex, + withProperties: () => flexBuilder, + withValidation: () => flexBuilder, + withItems: () => flexBuilder, + withLayout: () => flexBuilder + }); + + let inputBoxBuilder: azdata.ComponentBuilder = { + component: () => { + let r = inputBox(); + return r; + }, + withProperties: () => inputBoxBuilder, + withValidation: () => inputBoxBuilder + }; + let imageBuilder: azdata.ComponentBuilder = { + component: () => { + let r = image(); + return r; + }, + withProperties: () => imageBuilder, + withValidation: () => imageBuilder + }; + let dropdownBuilder: azdata.ComponentBuilder = { + component: () => { + let r = dropdown(); + return r; + }, + withProperties: () => dropdownBuilder, + withValidation: () => dropdownBuilder + }; + + let view: azdata.ModelView = { + onClosed: undefined!, + connection: undefined!, + serverInfo: undefined!, + valid: true, + onValidityChanged: undefined!, + validate: undefined!, + initializeModel: () => { return Promise.resolve(); }, + modelBuilder: { + radioCardGroup: undefined!, + navContainer: undefined!, + divContainer: undefined!, + flexContainer: () => flexBuilder, + splitViewContainer: undefined!, + dom: undefined!, + card: undefined!, + inputBox: () => inputBoxBuilder, + checkBox: undefined!, + radioButton: () => radioButtonBuilder, + webView: undefined!, + editor: undefined!, + diffeditor: undefined!, + text: () => inputBoxBuilder, + image: () => imageBuilder, + button: () => buttonBuilder, + dropDown: () => dropdownBuilder, + tree: undefined!, + listBox: undefined!, + table: undefined!, + declarativeTable: () => declarativeTableBuilder, + dashboardWidget: undefined!, + dashboardWebview: undefined!, + formContainer: () => formBuilder, + groupContainer: undefined!, + toolbarContainer: undefined!, + loadingComponent: () => loadingBuilder, + fileBrowserTree: undefined!, + hyperlink: undefined!, + tabbedPanel: undefined!, + separator: undefined! + } + }; + + return { + dialog: dialog, + model: model, + view: view, + onClick: onClick, + }; + } +}); diff --git a/extensions/notebook/src/test/managePackages/managePackagesDialogModel.test.ts b/extensions/notebook/src/test/managePackages/managePackagesDialogModel.test.ts index 8fe0deb61e..1c2b226457 100644 --- a/extensions/notebook/src/test/managePackages/managePackagesDialogModel.test.ts +++ b/extensions/notebook/src/test/managePackages/managePackagesDialogModel.test.ts @@ -50,7 +50,6 @@ describe('Manage Packages', () => { providers.set(provider.providerId, provider); let options = { - multiLocations: true, defaultLocation: 'invalid location' }; let model = new ManagePackagesDialogModel(jupyterServerInstallation, providers, options); @@ -64,29 +63,12 @@ describe('Manage Packages', () => { providers.set(provider.providerId, provider); let options = { - multiLocations: true, defaultProviderId: 'invalid provider' }; let model = new ManagePackagesDialogModel(jupyterServerInstallation, providers, options); await should(model.init()).rejectedWith(`Invalid default provider id '${options.defaultProviderId}`); }); - /* Test disabled. Tracking issue: https://github.com/microsoft/azuredatastudio/issues/8877 - it('Init should throw exception not given valid default location for single location mode', async function (): Promise { - let testContext = createContext(); - let provider = createProvider(testContext); - let providers = new Map(); - providers.set(provider.providerId, provider); - - let options = { - multiLocations: false - }; - let model = new ManagePackagesDialogModel(jupyterServerInstallation, providers, options); - await should(model.init()).rejectedWith(`Default location not specified for single location mode`); - }); - */ - - it('Init should set default options given undefined', async function (): Promise { let testContext = createContext(); let provider = createProvider(testContext); @@ -96,7 +78,6 @@ describe('Manage Packages', () => { let model = new ManagePackagesDialogModel(jupyterServerInstallation, providers, undefined); await model.init(); - should.equal(model.multiLocationMode, true); should.equal(model.defaultLocation, provider.packageTarget.location); should.equal(model.defaultProviderId, provider.providerId); }); @@ -119,14 +100,12 @@ describe('Manage Packages', () => { providers.set(testContext1.provider.providerId, createProvider(testContext1)); providers.set(testContext2.provider.providerId, createProvider(testContext2)); let options = { - multiLocations: false, defaultLocation: testContext2.provider.packageTarget.location, defaultProviderId: testContext2.provider.providerId }; let model = new ManagePackagesDialogModel(jupyterServerInstallation, providers, options); await model.init(); - should.equal(model.multiLocationMode, false); should.equal(model.defaultLocation, testContext2.provider.packageTarget.location); should.equal(model.defaultProviderId, testContext2.provider.providerId); }); @@ -195,7 +174,7 @@ describe('Manage Packages', () => { it('changeProvider should change current provider successfully', async function (): Promise { let testContext1 = createContext(); testContext1.provider.providerId = 'providerId1'; - testContext1.provider.getLocationTitle = () => Promise.resolve('location title 1'); + testContext1.provider.getLocations = () => Promise.resolve([{displayName: 'location title 1', name: 'location1'}]); testContext1.provider.packageTarget = { location: 'location1', packageType: 'package-type1' @@ -203,7 +182,7 @@ describe('Manage Packages', () => { let testContext2 = createContext(); testContext2.provider.providerId = 'providerId2'; - testContext2.provider.getLocationTitle = () => Promise.resolve('location title 2'); + testContext2.provider.getLocations = () => Promise.resolve([{displayName: 'location title 2', name: 'location2'}]); testContext2.provider.packageTarget = { location: 'location2', packageType: 'package-type2' @@ -217,7 +196,7 @@ describe('Manage Packages', () => { await model.init(); model.changeProvider('providerId2'); - should.deepEqual(await model.getLocationTitle(), 'location title 2'); + should.deepEqual(await model.getLocations(), [{displayName: 'location title 2', name: 'location2'}]); }); it('changeProvider should throw exception given invalid provider', async function (): Promise { @@ -283,7 +262,7 @@ describe('Manage Packages', () => { let testContext2 = createContext(); testContext2.provider.providerId = 'providerId2'; - testContext2.provider.getLocationTitle = () => Promise.resolve('location title 2'); + testContext2.provider.getLocations = () => Promise.resolve([{displayName: 'location title 2', name: 'location2'}]); testContext2.provider.packageTarget = { location: 'location2', packageType: 'package-type2' @@ -301,6 +280,12 @@ describe('Manage Packages', () => { testContext2.provider.listPackages = () => { return Promise.resolve(packages); }; + testContext1.provider.listPackages = () => { + return Promise.resolve([{ + name: 'p3', + version: '1.1.1.3' + }]); + }; let providers = new Map(); providers.set(testContext1.provider.providerId, createProvider(testContext1)); @@ -315,7 +300,50 @@ describe('Manage Packages', () => { await should(model.installPackages(packages)).resolved(); await should(model.uninstallPackages(packages)).resolved(); await should(model.getPackageOverview('p1')).resolved(); - await should(model.getLocationTitle()).resolvedWith('location title 2'); + await should(model.getLocations()).resolvedWith([{displayName: 'location title 2', name: 'location2'}]); + }); + + it('listPackages should return packages for current location', async function (): Promise { + let testContext = createContext(); + testContext.provider.providerId = 'providerId1'; + testContext.provider.packageTarget = { + location: 'location1', + packageType: 'package-type1' + }; + + let packages1 = [ + { + name: 'p1', + version: '1.1.1.1' + }, + { + name: 'p2', + version: '1.1.1.2' + } + ]; + let packages2 = [{ + name: 'p3', + version: '1.1.1.3' + }]; + testContext.provider.listPackages = (location) => { + if (location === 'location1') { + return Promise.resolve(packages1); + } else { + return Promise.resolve(packages2); + } + + }; + + let providers = new Map(); + providers.set(testContext.provider.providerId, createProvider(testContext)); + + let model = new ManagePackagesDialogModel(jupyterServerInstallation, providers, undefined); + + await model.init(); + model.changeProvider('providerId1'); + model.changeLocation('location2'); + + await should(model.listPackages()).resolvedWith(packages2); }); function createContext(): TestContext { @@ -327,7 +355,7 @@ describe('Manage Packages', () => { packageType: 'package-type' }, canUseProvider: () => { return Promise.resolve(true); }, - getLocationTitle: () => { return Promise.resolve('location-title'); }, + getLocations: () => { return Promise.resolve([{displayName: 'location-title', name: 'location'}]); }, installPackages:() => { return Promise.resolve(); }, uninstallPackages: (packages: IPackageDetails[]) => { return Promise.resolve(); }, listPackages: () => { return Promise.resolve([]); }, @@ -339,10 +367,10 @@ describe('Manage Packages', () => { function createProvider(testContext: TestContext): IPackageManageProvider { let mockProvider = TypeMoq.Mock.ofType(LocalPipPackageManageProvider); mockProvider.setup(x => x.canUseProvider()).returns(() => testContext.provider.canUseProvider()); - mockProvider.setup(x => x.getLocationTitle()).returns(() => testContext.provider.getLocationTitle()); - mockProvider.setup(x => x.installPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((packages, useMinVersion) => testContext.provider.installPackages(packages, useMinVersion)); - mockProvider.setup(x => x.uninstallPackages(TypeMoq.It.isAny())).returns((packages) => testContext.provider.uninstallPackages(packages)); - mockProvider.setup(x => x.listPackages()).returns(() => testContext.provider.listPackages()); + mockProvider.setup(x => x.getLocations()).returns(() => testContext.provider.getLocations()); + mockProvider.setup(x => x.installPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((packages, useMinVersion) => testContext.provider.installPackages(packages, useMinVersion)); + mockProvider.setup(x => x.uninstallPackages(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((packages) => testContext.provider.uninstallPackages(packages)); + mockProvider.setup(x => x.listPackages(TypeMoq.It.isAny())).returns(() => testContext.provider.listPackages()); mockProvider.setup(x => x.getPackageOverview(TypeMoq.It.isAny())).returns((name) => testContext.provider.getPackageOverview(name)); mockProvider.setup(x => x.packageTarget).returns(() => testContext.provider.packageTarget); mockProvider.setup(x => x.providerId).returns(() => testContext.provider.providerId); diff --git a/extensions/notebook/src/types.d.ts b/extensions/notebook/src/types.d.ts index 368d221a1b..984a47d7a5 100644 --- a/extensions/notebook/src/types.d.ts +++ b/extensions/notebook/src/types.d.ts @@ -65,6 +65,14 @@ export interface IPackageDetails { version: string; } +/** + * Package location + */ +export interface IPackageLocation { + name: string; + displayName: string; +} + /** * Package target interface */ @@ -99,20 +107,22 @@ export interface IPackageManageProvider { /** * Returns list of installed packages */ - listPackages(): Promise; + listPackages(location?: string): Promise; /** * Installs give packages * @param package Packages to install * @param useMinVersion if true, minimal version will be used + * @param location package location */ - installPackages(package: IPackageDetails[], useMinVersion: boolean): Promise; + installPackages(package: IPackageDetails[], useMinVersion: boolean, location?: string): Promise; /** * Uninstalls given packages * @param package package to uninstall + * @param location package location */ - uninstallPackages(package: IPackageDetails[]): Promise; + uninstallPackages(package: IPackageDetails[], location?: string): Promise; /** * Returns true if the provider can be used in current context @@ -122,7 +132,7 @@ export interface IPackageManageProvider { /** * Returns location title */ - getLocationTitle(): Promise; + getLocations(): Promise; /** * Returns Package Overview