Machine Learning Services extension with package management feature (#8622)

* Machine Learning Services extension with package management feature
This commit is contained in:
Leila Lali
2019-12-17 09:55:42 -08:00
committed by GitHub
parent ef8c0e91e6
commit 0a6dc2720d
34 changed files with 3923 additions and 11 deletions

View File

@@ -204,7 +204,8 @@ const externalExtensions = [
'cms',
'query-history',
'liveshare',
'sql-database-projects'
'sql-database-projects',
'machine-learning-services'
];
// extensions that require a rebuild since they have native parts
const rebuildExtensions = [

View File

@@ -240,7 +240,8 @@ const externalExtensions = [
'cms',
'query-history',
'liveshare',
'sql-database-projects'
'sql-database-projects',
'machine-learning-services'
];
// extensions that require a rebuild since they have native parts

View File

@@ -0,0 +1,2 @@
*.vsix
python/**

View File

@@ -0,0 +1,4 @@
src/**
tsconfig.json
python/**
out/test/**

View File

@@ -0,0 +1,28 @@
# Machine Learning Services for Azure Data Studio #
Machine Learning Services for Azure Data Studio (Preview) provides support for new features that help you create, build, and deploy machine learning jobs in SQL Server through Azure Data Studio.
## Features ##
* Enable Machine Learning Services on SQL Server.
* Deploy an MLFlow container to track models.
* Follow along with machine learning notebooks.
## Prerequisites ##
In order to use Machine Learning Services for Azure Data Studio (Preview), your SQL Server must have Machine Learning Services installed. Follow the instructions [here](https://docs.microsoft.com/sql/advanced-analytics/install/sql-machine-learning-services-windows-install?view=sql-server-ver15) to do so.
## Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
## Privacy Statement
The [Microsoft Enterprise and Developer Privacy Statement](https://privacy.microsoft.com/privacystatement) describes the privacy statement of this software.
## License
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the [Source EULA](https://raw.githubusercontent.com/Microsoft/azuredatastudio/master/LICENSE.txt).

View File

@@ -0,0 +1,6 @@
{
"requiredPythonPackages": [
{ "name": "pymssql", "version": "2.1.4" },
{ "name": "sqlmlutils", "version": ""}
]
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1,54 @@
{
"name": "machineLearningServices",
"displayName": "%displayName%",
"description": "%description%",
"version": "1.0.0",
"publisher": "Microsoft",
"preview": true,
"engines": {
"vscode": "^1.25.0",
"azdata": ">=1.13.0"
},
"activationEvents": [
"onCommand:ml.command.managePackages"
],
"license": "https://raw.githubusercontent.com/Microsoft/azuredatastudio/master/LICENSE.txt",
"icon": "images/ML_ExtensionIcon.png",
"aiKey": "AIF-37eefaf0-8022-4671-a3fb-64752724682e",
"main": "./out/main",
"repository": {
"type": "git",
"url": "https://github.com/Microsoft/azuredatastudio.git"
},
"extensionDependencies": [
"Microsoft.mssql",
"Microsoft.notebook"
],
"contributes": {
"commands": [
{
"command": "ml.command.managePackages",
"title": "%ml.command.managePackages%"
}
]
},
"dependencies": {
"vscode-nls": "^4.0.0"
},
"devDependencies": {
"@types/mocha": "^5.2.5",
"@types/node": "^10.14.8",
"@types/uuid": "^3.4.5",
"mocha": "^5.2.0",
"mocha-junit-reporter": "^1.17.0",
"mocha-multi-reporters": "^1.1.7",
"should": "^13.2.1",
"typemoq": "^2.1.0",
"vscode": "1.1.26"
},
"__metadata": {
"id": "56",
"publisherDisplayName": "Microsoft",
"publisherId": "Microsoft"
}
}

View File

@@ -0,0 +1,10 @@
{
"displayName": "SQL Server Machine Learning Services",
"description": "SQL Server Machine Learning Services",
"mlServices.enable": "Enable Machine Learning Services",
"mlServices.disable": "Disable Machine Learning Services",
"title.tasks": "Getting Started",
"title.endpoints": "Endpoints",
"title.books": "Machine Learning Services Books",
"ml.command.managePackages": "Manage Packages in SQL Server"
}

View File

@@ -0,0 +1,65 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
/**
* Wrapper class to act as a facade over VSCode and Data APIs and allow us to test / mock callbacks into
* this API from our code
*/
export class ApiWrapper {
public createOutputChannel(name: string): vscode.OutputChannel {
return vscode.window.createOutputChannel(name);
}
public createTerminalWithOptions(options: vscode.TerminalOptions): vscode.Terminal {
return vscode.window.createTerminal(options);
}
public getCurrentConnection(): Thenable<azdata.connection.ConnectionProfile> {
return azdata.connection.getCurrentConnection();
}
public getCredentials(connectionId: string): Thenable<{ [name: string]: string }> {
return azdata.connection.getCredentials(connectionId);
}
public registerCommand(command: string, callback: (...args: any[]) => any, thisArg?: any): vscode.Disposable {
return vscode.commands.registerCommand(command, callback, thisArg);
}
public executeCommand<T>(command: string, ...rest: any[]): Thenable<T | undefined> {
return vscode.commands.executeCommand(command, ...rest);
}
public getUriForConnection(connectionId: string): Thenable<string> {
return azdata.connection.getUriForConnection(connectionId);
}
public getProvider<T extends azdata.DataProvider>(providerId: string, providerType: azdata.DataProviderType): T {
return azdata.dataprotocol.getProvider<T>(providerId, providerType);
}
public showErrorMessage(message: string, ...items: string[]): Thenable<string | undefined> {
return vscode.window.showErrorMessage(message, ...items);
}
public showInfoMessage(message: string, ...items: string[]): Thenable<string | undefined> {
return vscode.window.showInformationMessage(message, ...items);
}
public showOpenDialog(options: vscode.OpenDialogOptions): Thenable<vscode.Uri[] | undefined> {
return vscode.window.showOpenDialog(options);
}
public startBackgroundOperation(operationInfo: azdata.BackgroundOperationInfo): void {
azdata.tasks.startBackgroundOperation(operationInfo);
}
public getExtension(extensionId: string): vscode.Extension<any> | undefined {
return vscode.extensions.getExtension(extensionId);
}
}

View File

@@ -0,0 +1,36 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { promises as fs } from 'fs';
import * as path from 'path';
import * as nbExtensionApis from '../typings/notebookServices';
const configFileName = 'config.json';
/**
* Extension Configuration
*/
export class Config {
private _configValues: any;
constructor(private _root: string) {
}
/**
* Loads the config values
*/
public async load(): Promise<void> {
const rawConfig = await fs.readFile(path.join(this._root, configFileName));
this._configValues = JSON.parse(rawConfig.toString());
}
/**
* Returns the config value of required packages
*/
public get requiredPythonPackages(): nbExtensionApis.IPackageDetails[] {
return this._configValues.requiredPythonPackages;
}
}

View File

@@ -0,0 +1,28 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as nls from 'vscode-nls';
const localize = nls.loadMessageBundle();
export const winPlatform = 'win32';
export const pythonBundleVersion = '0.0.1';
export const managePackagesCommand = 'jupyter.cmd.managePackages';
export const mlManagePackagesCommand = 'ml.command.managePackages';
export const extensionOutputChannel = 'Machine Learning Services';
export const notebookExtensionName = 'Microsoft.notebook';
// Localized texts
//
export const managePackageCommandError = localize('ml.managePackages.error', "Either no connection is available or the server does not have external script enabled.");
export function installDependenciesError(err: string): string { return localize('ml.installDependencies.error', "Failed to install dependencies. Error: {0}", err); }
export const installDependenciesMsgTaskName = localize('ml.installDependencies.msgTaskName', "Installing Machine Learning extension dependencies");
export const installDependenciesPackages = localize('ml.installDependencies.packages', "Installing required packages ...");
export const installDependenciesPackagesAlreadyInstalled = localize('ml.installDependencies.packagesAlreadyInstalled', "Required packages are already installed.");
export function installDependenciesGetPackagesError(err: string): string { return localize('ml.installDependencies.getPackagesError', "Failed to get installed python packages. Error: {0}", err); }
export const packageManagerNoConnection = localize('ml.packageManager.NoConnection', "No connection selected");
export const notebookExtensionNotLoaded = localize('ml.notebookExtensionNotLoaded', "Notebook extension is not loaded");

View File

@@ -0,0 +1,81 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as vscode from 'vscode';
import * as childProcess from 'child_process';
const ExecScriptsTimeoutInSeconds = 600000;
export class ProcessService {
public async execScripts(exeFilePath: string, scripts: string[], outputChannel?: vscode.OutputChannel): Promise<void> {
return new Promise<void>((resolve, reject) => {
const scriptExecution = childProcess.spawn(exeFilePath);
let output: string;
scripts.forEach(script => {
scriptExecution.stdin.write(`${script}\n`);
});
scriptExecution.stdin.end();
// Add listeners to print stdout and stderr if an output channel was provided
if (outputChannel) {
scriptExecution.stdout.on('data', data => {
this.outputDataChunk(data, outputChannel, ' stdout: ');
output = output + data.toString();
});
scriptExecution.stderr.on('data', data => {
this.outputDataChunk(data, outputChannel, ' stderr: ');
output = output + data.toString();
});
}
scriptExecution.on('exit', (code) => {
if (code === 0) {
resolve();
} else {
reject(`Process exited with code: ${code}. output: ${output}`);
}
});
setTimeout(() => {
try {
scriptExecution.kill();
} catch (error) {
console.log(error);
}
}, ExecScriptsTimeoutInSeconds);
});
}
public async executeBufferedCommand(cmd: string, outputChannel?: vscode.OutputChannel): Promise<string> {
return new Promise<string>((resolve, reject) => {
if (outputChannel) {
outputChannel.appendLine(` > ${cmd}`);
}
let child = childProcess.exec(cmd, (err, stdout) => {
if (err) {
reject(err);
} else {
resolve(stdout);
}
});
// Add listeners to print stdout and stderr if an output channel was provided
if (outputChannel) {
child.stdout.on('data', data => { this.outputDataChunk(data, outputChannel, ' stdout: '); });
child.stderr.on('data', data => { this.outputDataChunk(data, outputChannel, ' stderr: '); });
}
});
}
private outputDataChunk(data: string | Buffer, outputChannel: vscode.OutputChannel, header: string): void {
data.toString().split(/\r?\n/)
.forEach(line => {
outputChannel.appendLine(header + line);
});
}
}

View File

@@ -0,0 +1,124 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as azdata from 'azdata';
import * as nbExtensionApis from '../typings/notebookServices';
import { ApiWrapper } from './apiWrapper';
const listPythonPackagesQuery = `
EXEC sp_execute_external_script
@language=N'Python',
@script=N'import pkg_resources
import pandas
OutputDataSet = pandas.DataFrame([(d.project_name, d.version) for d in pkg_resources.working_set])'
`;
const checkMlInstalledQuery = `
Declare @tablevar table(name NVARCHAR(MAX), min INT, max INT, config_value bit, run_value bit)
insert into @tablevar(name, min, max, config_value, run_value) exec sp_configure
Declare @external_script_enabled bit
SELECT @external_script_enabled=config_value FROM @tablevar WHERE name = 'external scripts enabled'
SELECT @external_script_enabled`;
const checkPythonInstalledQuery = `
SELECT is_installed
FROM sys.dm_db_external_language_stats s, sys.external_languages l
WHERE s.external_language_id = l.external_language_id AND language = 'Python'`;
const modifyExternalScriptConfigQuery = `
EXEC sp_configure 'external scripts enabled', #CONFIG_VALUE#;
RECONFIGURE WITH OVERRIDE;
Declare @tablevar table(name NVARCHAR(MAX), min INT, max INT, config_value bit, run_value bit)
insert into @tablevar(name, min, max, config_value, run_value) exec sp_configure
Declare @external_script_enabled bit
SELECT @external_script_enabled=config_value FROM @tablevar WHERE name = 'external scripts enabled'
SELECT @external_script_enabled`;
/**
* SQL Query runner
*/
export class QueryRunner {
constructor(private _apiWrapper: ApiWrapper) {
}
/**
* Returns python packages installed in SQL server instance
* @param connection SQL Connection
*/
public async getPythonPackages(connection: azdata.connection.ConnectionProfile): Promise<nbExtensionApis.IPackageDetails[]> {
let packages: nbExtensionApis.IPackageDetails[] = [];
let result = await this.runQuery(connection, listPythonPackagesQuery);
if (result && result.rows.length > 0) {
packages = result.rows.map(row => {
return {
name: row[0].displayValue,
version: row[1].displayValue
};
});
}
return packages;
}
/**
* Updates External Script Config in a SQL server instance
* @param connection SQL Connection
* @param enable if true the config will be enabled otherwise it will be disabled
*/
public async updateExternalScriptConfig(connection: azdata.connection.ConnectionProfile, enable: boolean): Promise<void> {
let query = modifyExternalScriptConfigQuery;
let configValue = enable ? '1' : '0';
query = query.replace('#CONFIG_VALUE#', configValue);
await this.runQuery(connection, query);
}
/**
* Returns true if python installed in the give SQL server instance
*/
public async isPythonInstalled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
let result = await this.runQuery(connection, checkPythonInstalledQuery);
let isInstalled = false;
if (result && result.rows && result.rows.length > 0) {
isInstalled = result.rows[0][0].displayValue === '1';
}
return isInstalled;
}
/**
* Returns true if mls is installed in the give SQL server instance
*/
public async isMachineLearningServiceEnabled(connection: azdata.connection.ConnectionProfile): Promise<boolean> {
let result = await this.runQuery(connection, checkMlInstalledQuery);
let isEnabled = false;
if (result && result.rows && result.rows.length > 0) {
isEnabled = result.rows[0][0].displayValue === '1';
}
return isEnabled;
}
private async runQuery(connection: azdata.connection.ConnectionProfile, query: string): Promise<azdata.SimpleExecuteResult | undefined> {
let result: azdata.SimpleExecuteResult | undefined = undefined;
try {
if (connection) {
let connectionUri = await this._apiWrapper.getUriForConnection(connection.connectionId);
let queryProvider = this._apiWrapper.getProvider<azdata.QueryProvider>(connection.providerId, azdata.DataProviderType.QueryProvider);
if (queryProvider) {
result = await queryProvider.runQueryAndReturn(connectionUri, query);
}
}
} catch (error) {
console.log(error);
}
return result;
}
}

View File

@@ -0,0 +1,48 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as uuid from 'uuid';
import * as path from 'path';
import * as os from 'os';
import * as fs from 'fs';
import * as constants from '../common/constants';
import { promisify } from 'util';
export async function execCommandOnTempFile<T>(content: string, command: (filePath: string) => Promise<T>): Promise<T> {
let tempFilePath: string = '';
try {
tempFilePath = path.join(os.tmpdir(), `ads_ml_temp_${uuid.v4()}`);
await fs.promises.writeFile(tempFilePath, content);
let result = await command(tempFilePath);
return result;
}
finally {
await fs.promises.unlink(tempFilePath);
}
}
export async function exists(path: string): Promise<boolean> {
return promisify(fs.exists)(path);
}
export async function createFolder(dirPath: string): Promise<void> {
let folderExists = await exists(dirPath);
if (!folderExists) {
await fs.promises.mkdir(dirPath);
}
}
export function getPythonInstallationLocation(rootFolder: string) {
return path.join(rootFolder, 'python');
}
export function getPythonExePath(rootFolder: string): string {
return path.join(
getPythonInstallationLocation(rootFolder),
constants.pythonBundleVersion,
process.platform === constants.winPlatform ? 'python.exe' : 'bin/python3');
}

View File

@@ -0,0 +1,112 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as vscode from 'vscode';
import * as nbExtensionApis from '../typings/notebookServices';
import { PackageManager } from '../packageManagement/packageManager';
import * as constants from '../common/constants';
import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import { ProcessService } from '../common/processService';
import { Config } from '../common/config';
/**
* The main controller class that initializes the extension
*/
export default class MainController implements vscode.Disposable {
private _outputChannel: vscode.OutputChannel;
private _rootPath = this._context.extensionPath;
private _config: Config;
public constructor(
private _context: vscode.ExtensionContext,
private _apiWrapper: ApiWrapper,
private _queryRunner: QueryRunner,
private _processService: ProcessService,
private _packageManager?: PackageManager
) {
this._outputChannel = this._apiWrapper.createOutputChannel(constants.extensionOutputChannel);
this._rootPath = this._context.extensionPath;
this._config = new Config(this._rootPath);
}
/**
* Deactivates the extension
*/
public deactivate(): void {
}
/**
* Activates the extension
*/
public async activate(): Promise<boolean> {
await this.initialize();
return Promise.resolve(true);
}
/**
* Returns an instance of Server Installation from notebook extension
*/
private async getNotebookExtensionApis(): Promise<nbExtensionApis.IExtensionApi> {
let nbExtension = this._apiWrapper.getExtension(constants.notebookExtensionName);
if (nbExtension) {
await nbExtension.activate();
return (nbExtension.exports as nbExtensionApis.IExtensionApi);
} else {
throw new Error(constants.notebookExtensionNotLoaded);
}
}
private async initialize(): Promise<void> {
this._outputChannel.show(true);
let nbApis = await this.getNotebookExtensionApis();
await this._config.load();
let packageManager = this.getPackageManager(nbApis);
this._apiWrapper.registerCommand(constants.mlManagePackagesCommand, (async () => {
await packageManager.managePackages();
}));
try {
await packageManager.installDependencies();
} catch (err) {
this._outputChannel.appendLine(err);
}
}
/**
* Returns the package manager instance
*/
public getPackageManager(nbApis: nbExtensionApis.IExtensionApi): PackageManager {
if (!this._packageManager) {
this._packageManager = new PackageManager(nbApis, this._outputChannel, this._rootPath, this._apiWrapper, this._queryRunner, this._processService, this._config);
this._packageManager.init();
}
return this._packageManager;
}
/**
* Package manager instance
*/
public set packageManager(value: PackageManager) {
this._packageManager = value;
}
/**
* Config instance
*/
public get config(): Config {
return this._config;
}
/**
* Disposes the extension
*/
public dispose(): void {
this.deactivate();
}
}

View File

@@ -0,0 +1,33 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as vscode from 'vscode';
import MainController from './controllers/mainController';
import { ApiWrapper } from './common/apiWrapper';
import { QueryRunner } from './common/queryRunner';
import { ProcessService } from './common/processService';
let controllers: MainController[] = [];
export async function activate(context: vscode.ExtensionContext): Promise<void> {
let apiWrapper = new ApiWrapper();
let queryRunner = new QueryRunner(apiWrapper);
let processService = new ProcessService();
// Start the main controller
//
let mainController = new MainController(context, apiWrapper, queryRunner, processService);
controllers.push(mainController);
context.subscriptions.push(mainController);
await mainController.activate();
}
export function deactivate(): void {
for (let controller of controllers) {
controller.deactivate();
}
}

View File

@@ -0,0 +1,156 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as nbExtensionApis from '../typings/notebookServices';
import { SqlPythonPackageManageProvider } from './sqlPackageManageProvider';
import { QueryRunner } from '../common/queryRunner';
import * as utils from '../common/utils';
import * as constants from '../common/constants';
import { ApiWrapper } from '../common/apiWrapper';
import { ProcessService } from '../common/processService';
import { Config } from '../common/config';
import { isNullOrUndefined } from 'util';
export class PackageManager {
private _pythonExecutable: string = '';
private _pythonInstallationLocation: string = '';
private _sqlPackageManager: SqlPythonPackageManageProvider | undefined = undefined;
/**
* Creates a new instance of PackageManager
*/
constructor(
private _nbExtensionApis: nbExtensionApis.IExtensionApi,
private _outputChannel: vscode.OutputChannel,
private _rootFolder: string,
private _apiWrapper: ApiWrapper,
private _queryRunner: QueryRunner,
private _processService: ProcessService,
private _config: Config) {
}
/**
* Initializes the instance and resister SQL package manager with manage package dialog
*/
public init(): void {
this._pythonInstallationLocation = utils.getPythonInstallationLocation(this._rootFolder);
this._pythonExecutable = utils.getPythonExePath(this._rootFolder);
this._sqlPackageManager = new SqlPythonPackageManageProvider(this._nbExtensionApis, this._outputChannel, this._rootFolder, this._apiWrapper, this._queryRunner, this._processService);
this._nbExtensionApis.registerPackageManager(SqlPythonPackageManageProvider.ProviderId, this._sqlPackageManager);
}
/**
* Executes manage package command for SQL server packages.
*/
public async managePackages(): Promise<void> {
// Only execute the command if there's a valid connection with ml configuration enabled
//
let connection = await this.getCurrentConnection();
let isPythonInstalled = await this._queryRunner.isPythonInstalled(connection);
if (connection && isPythonInstalled && this._sqlPackageManager) {
this._apiWrapper.executeCommand(constants.managePackagesCommand, {
multiLocations: false,
defaultLocation: this._sqlPackageManager.packageTarget.location,
defaultProviderId: SqlPythonPackageManageProvider.ProviderId
});
} else {
this._apiWrapper.showInfoMessage(constants.managePackageCommandError);
}
}
/**
* Installs dependencies for the extension
*/
public async installDependencies(): Promise<void> {
return new Promise<void>((resolve, reject) => {
let msgTaskName = constants.installDependenciesMsgTaskName;
this._apiWrapper.startBackgroundOperation({
displayName: msgTaskName,
description: msgTaskName,
isCancelable: false,
operation: async op => {
try {
if (!(await utils.exists(this._pythonExecutable))) {
// Install python
//
await utils.createFolder(this._pythonInstallationLocation);
await this.jupyterInstallation.installPythonPackage(op, false, this._pythonInstallationLocation, this._outputChannel);
}
// Install required packages
//
await this.installRequiredPythonPackages();
op.updateStatus(azdata.TaskStatus.Succeeded);
resolve();
} catch (error) {
let errorMsg = constants.installDependenciesError(error ? error.message : '');
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
reject(errorMsg);
}
}
});
});
}
/**
* Installs required python packages
*/
private async installRequiredPythonPackages(): Promise<void> {
let installedPackages = await this.getInstalledPipPackages();
let fileContent = '';
this._config.requiredPythonPackages.forEach(packageDetails => {
let hasVersion = ('version' in packageDetails) && !isNullOrUndefined(packageDetails['version']) && packageDetails['version'].length > 0;
if (!installedPackages.find(x => x.name === packageDetails['name'] && (!hasVersion || packageDetails['version'] === x.version))) {
let packageNameDetail = hasVersion ? `${packageDetails.name}==${packageDetails.version}` : `${packageDetails.name}`;
fileContent = `${fileContent}${packageNameDetail}\n`;
}
});
if (fileContent) {
this._outputChannel.appendLine(constants.installDependenciesPackages);
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
return await this.installPackages(tempFilePath);
});
this._outputChannel.appendLine(result);
} else {
this._outputChannel.appendLine(constants.installDependenciesPackagesAlreadyInstalled);
}
}
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
try {
let cmd = `"${this._pythonExecutable}" -m pip list --format=json`;
let packagesInfo = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
if (packagesInfo) {
packagesResult = <nbExtensionApis.IPackageDetails[]>JSON.parse(packagesInfo);
}
return packagesResult;
}
catch (err) {
this._outputChannel.appendLine(constants.installDependenciesGetPackagesError(err ? err.message : ''));
return [];
}
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
private get jupyterInstallation(): nbExtensionApis.IJupyterServerInstallation {
return this._nbExtensionApis.getJupyterController().jupyterInstallation;
}
private async installPackages(requirementFilePath: string): Promise<string> {
let cmd = `"${this._pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
}
}

View File

@@ -0,0 +1,183 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as nbExtensionApis from '../typings/notebookServices';
import * as utils from '../common/utils';
import * as constants from '../common/constants';
import { QueryRunner } from '../common/queryRunner';
import { ApiWrapper } from '../common/apiWrapper';
import { ProcessService } from '../common/processService';
const installMode = 'install';
const uninstallMode = 'uninstall';
const localPythonProviderId = 'localhost_Pip';
/**
* Manage Package Provider for python packages inside SQL server databases
*/
export class SqlPythonPackageManageProvider implements nbExtensionApis.IPackageManageProvider {
private _pythonExecutable: string;
public static ProviderId = 'sql_Python';
/**
* Creates new a instance
*/
constructor(
private _nbExtensionApis: nbExtensionApis.IExtensionApi,
private _outputChannel: vscode.OutputChannel,
private _rootFolder: string,
private _apiWrapper: ApiWrapper,
private _queryRunner: QueryRunner,
private _processService: ProcessService) {
this._pythonExecutable = utils.getPythonExePath(this._rootFolder);
}
/**
* Returns provider Id
*/
public get providerId(): string {
return SqlPythonPackageManageProvider.ProviderId;
}
/**
* Returns package target
*/
public get packageTarget(): nbExtensionApis.IPackageTarget {
return { location: 'SQL', packageType: 'Python' };
}
/**
* Returns list of packages
*/
public async listPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
let packages = await this._queryRunner.getPythonPackages(await this.getCurrentConnection());
if (packages) {
packages = packages.sort((a, b) => a.name.localeCompare(b.name));
} else {
packages = [];
}
return packages;
}
/**
* Installs given packages
* @param packages Packages to install
* @param useMinVersion minimum version
*/
async installPackages(packages: nbExtensionApis.IPackageDetails[], useMinVersion: boolean): Promise<void> {
if (packages) {
// TODO: install package as parallel
for (let index = 0; index < packages.length; index++) {
const element = packages[index];
await this.updatePackage(element, installMode);
}
}
//TODO: use useMinVersion
console.log(useMinVersion);
}
/**
* Execute a script to install or uninstall a python package inside current SQL Server connection
* @param packageDetails Packages to install or uninstall
* @param scriptMode can be 'install' or 'uninstall'
*/
private async updatePackage(packageDetails: nbExtensionApis.IPackageDetails, scriptMode: string): Promise<void> {
let connection = await this.getCurrentConnection();
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
if (connection) {
let port = '1433';
let server = connection.serverName;
let database = connection.databaseName ? `, database="${connection.databaseName}"` : '';
let index = connection.serverName.indexOf(',');
if (index > 0) {
port = connection.serverName.substring(index + 1);
server = connection.serverName.substring(0, index);
}
let pythonConnectionParts = `server="${server}", port=${port}, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"${database})`;
let pythonCommandScript = scriptMode === installMode ?
`pkgmanager.install(package="${packageDetails.name}", version="${packageDetails.version}")` :
`pkgmanager.uninstall(package_name="${packageDetails.name}")`;
let scripts: string[] = [
'import sqlmlutils',
`connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 17 for SQL Server", ${pythonConnectionParts}`,
'pkgmanager = sqlmlutils.SQLPackageManager(connection)',
pythonCommandScript
];
await this._processService.execScripts(this._pythonExecutable, scripts, this._outputChannel);
}
}
/**
* Uninstalls given packages
* @param packages Packages to uninstall
*/
async uninstallPackages(packages: nbExtensionApis.IPackageDetails[]): Promise<void> {
for (let index = 0; index < packages.length; index++) {
const element = packages[index];
await this.updatePackage(element, uninstallMode);
}
}
/**
* Returns true if the provider can be used
*/
async canUseProvider(): Promise<boolean> {
let connection = await this.getCurrentConnection();
if (connection && await this._queryRunner.isPythonInstalled(connection)) {
return true;
}
return false;
}
/**
* Returns package overview for given name
* @param packageName Package Name
*/
async getPackageOverview(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
let packagePreview: nbExtensionApis.IPackageOverview = {
name: packageName,
versions: [],
summary: ''
};
let pythonPackageProvider = this.pythonPackageProvider;
if (pythonPackageProvider) {
packagePreview = await pythonPackageProvider.getPackageOverview(packageName);
}
return packagePreview;
}
/**
* Returns location title
*/
async getLocationTitle(): Promise<string> {
let connection = await this.getCurrentConnection();
if (connection) {
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`;
}
return constants.packageManagerNoConnection;
}
private get pythonPackageProvider(): nbExtensionApis.IPackageManageProvider | undefined {
let providers = this._nbExtensionApis.getPackageManagers();
if (providers && providers.has(localPythonProviderId)) {
return providers.get(localPythonProviderId);
}
return undefined;
}
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
return await this._apiWrapper.getCurrentConnection();
}
}

View File

@@ -0,0 +1,48 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
const path = require('path');
const testRunner = require('vscode/lib/testrunner');
const suite = 'machine learning Extension Tests';
const options: any = {
ui: 'bdd',
useColors: true,
timeout: 600000
};
// set relevant mocha options from the environment
if (process.env.ADS_TEST_GREP) {
options.grep = process.env.ADS_TEST_GREP;
console.log(`setting options.grep to: ${options.grep}`);
}
if (process.env.ADS_TEST_INVERT_GREP) {
options.invert = parseInt(process.env.ADS_TEST_INVERT_GREP);
console.log(`setting options.invert to: ${options.invert}`);
}
if (process.env.ADS_TEST_TIMEOUT) {
options.timeout = parseInt(process.env.ADS_TEST_TIMEOUT);
console.log(`setting options.timeout to: ${options.timeout}`);
}
if (process.env.ADS_TEST_RETRIES) {
options.retries = parseInt(process.env.ADS_TEST_RETRIES);
console.log(`setting options.retries to: ${options.retries}`);
}
if (process.env.BUILD_ARTIFACTSTAGINGDIRECTORY) {
options.reporter = 'mocha-multi-reporters';
options.reporterOptions = {
reporterEnabled: 'spec, mocha-junit-reporter',
mochaJunitReporterReporterOptions: {
testsuitesTitle: `${suite} ${process.platform}`,
mochaFile: path.join(process.env.BUILD_ARTIFACTSTAGINGDIRECTORY, `test-results/${process.platform}-${suite.toLowerCase().replace(/[^\w]/g, '-')}-results.xml`)
}
};
}
testRunner.configure(options);
export = testRunner;

View File

@@ -0,0 +1,117 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as vscode from 'vscode';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import * as path from 'path';
import { ApiWrapper } from '../common/apiWrapper';
import { QueryRunner } from '../common/queryRunner';
import { ProcessService } from '../common/processService';
import MainController from '../controllers/mainController';
import { PackageManager } from '../packageManagement/packageManager';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
queryRunner: TypeMoq.IMock<QueryRunner>;
processService: TypeMoq.IMock<ProcessService>;
context: vscode.ExtensionContext;
outputChannel: vscode.OutputChannel;
extension: vscode.Extension<any>;
packageManager: TypeMoq.IMock<PackageManager>;
}
function createContext(): TestContext {
let extensionPath = path.join(__dirname, '..', '..');
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
processService: TypeMoq.Mock.ofType(ProcessService),
packageManager: TypeMoq.Mock.ofType(PackageManager),
context: {
subscriptions: [],
workspaceState: {
get: () => {return undefined;},
update: () => {return Promise.resolve();}
},
globalState: {
get: () => {return Promise.resolve();},
update: () => {return Promise.resolve();}
},
extensionPath: extensionPath,
asAbsolutePath: () => {return '';},
storagePath: '',
globalStoragePath: '',
logPath: ''
},
outputChannel: {
name: '',
append: () => { },
appendLine: () => { },
clear: () => { },
show: () => { },
hide: () => { },
dispose: () => { }
},
extension: {
id: '',
extensionPath: '',
isActive: true,
packageJSON: {},
extensionKind: vscode.ExtensionKind.UI,
exports: {},
activate: () => { return Promise.resolve(); }
}
};
}
function createController(testContext: TestContext): MainController {
let controller = new MainController(testContext.context, testContext.apiWrapper.object, testContext.queryRunner.object, testContext.processService.object);
controller.packageManager = testContext.packageManager.object;
return controller;
}
describe('Main Controller', () => {
it('Should create new instance successfully', async function (): Promise<void> {
let testContext = createContext();
testContext.apiWrapper.setup(x => x.createOutputChannel(TypeMoq.It.isAny())).returns(() => testContext.outputChannel);
should.doesNotThrow(() => createController(testContext));
});
it('initialize Should install dependencies successfully', async function (): Promise<void> {
let testContext = createContext();
testContext.apiWrapper.setup(x => x.createOutputChannel(TypeMoq.It.isAny())).returns(() => testContext.outputChannel);
testContext.apiWrapper.setup(x => x.getExtension(TypeMoq.It.isAny())).returns(() => testContext.extension);
testContext.packageManager.setup(x => x.managePackages()).returns(() => Promise.resolve());
testContext.packageManager.setup(x => x.installDependencies()).returns(() => Promise.resolve());
testContext.apiWrapper.setup(x => x.registerCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny()));
let controller = createController(testContext);
await controller.activate();
should.deepEqual(controller.config.requiredPythonPackages, [
{ name: 'pymssql', version: '2.1.4' },
{ name: 'sqlmlutils', version: '' }
]);
});
it('initialize Should show and error in output channel if installing dependencies fails', async function (): Promise<void> {
let errorReported = false;
let testContext = createContext();
testContext.apiWrapper.setup(x => x.createOutputChannel(TypeMoq.It.isAny())).returns(() => testContext.outputChannel);
testContext.apiWrapper.setup(x => x.getExtension(TypeMoq.It.isAny())).returns(() => testContext.extension);
testContext.packageManager.setup(x => x.managePackages()).returns(() => Promise.resolve());
testContext.packageManager.setup(x => x.installDependencies()).returns(() => Promise.reject());
testContext.apiWrapper.setup(x => x.registerCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny()));
testContext.outputChannel.appendLine = () => {
errorReported = true;
};
let controller = createController(testContext);
await controller.activate();
should.equal(errorReported, true);
});
});

View File

@@ -0,0 +1,223 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import { PackageManager } from '../../packageManagement/packageManager';
import { SqlPythonPackageManageProvider } from '../../packageManagement/sqlPackageManageProvider';
import { createContext, TestContext } from './utils';
describe('Package Manager', () => {
it('Should initialize SQL package manager successfully', async function (): Promise<void> {
let testContext = createContext();
should.doesNotThrow(() => createPackageManager(testContext));
should.equal(testContext.nbExtensionApis.getPackageManagers().has(SqlPythonPackageManageProvider.ProviderId), true);
});
it('Manage Package command Should execute the command for valid connection', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.queryRunner.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(true);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Manage Package command Should show an error for connection without python installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny()));
testContext.queryRunner.setup(x => x.isPythonInstalled(connection)).returns(() => {return Promise.resolve(false);});
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('Manage Package command Should show an error for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => {return Promise.resolve(connection);});
testContext.apiWrapper.setup(x => x.executeCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve();});
testContext.apiWrapper.setup(x => x.showInfoMessage(TypeMoq.It.isAny()));
let packageManager = createPackageManager(testContext);
await packageManager.managePackages();
testContext.apiWrapper.verify(x => x.showInfoMessage(TypeMoq.It.isAny()), TypeMoq.Times.once());
});
it('installDependencies Should install python if does not exist', async function (): Promise<void> {
let testContext = createContext();
let pythonInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"},
{"name":"sqlmlutils","version":"1.1.1"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.jupyterInstallation.installPythonPackage = () => {
pythonInstalled = true;
return Promise.resolve();
};
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(pythonInstalled, true);
});
it('installDependencies Should fail the task if installing python fails', async function (): Promise<void> {
let testContext = createContext();
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"},
{"name":"sqlmlutils","version":"1.1.1"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.jupyterInstallation.installPythonPackage = () => {
return Promise.reject();
};
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {return Promise.resolve(installedPackages);});
let packageManager = createPackageManager(testContext);
await should(packageManager.installDependencies()).rejected();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Failed);
});
it('installDependencies Should not install packages if already installed', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"},
{"name":"sqlmlutils","version":"1.1.1"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.jupyterInstallation.installPythonPackage = () => {
return Promise.resolve();
};
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
}
return Promise.resolve(installedPackages);
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, false);
});
it('installDependencies Should install packages that are not already installed', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.jupyterInstallation.installPythonPackage = () => {
return Promise.resolve();
};
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
}
return Promise.resolve(installedPackages);
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, true);
});
it('installDependencies Should install packages if list packages fails', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.jupyterInstallation.installPythonPackage = () => {
return Promise.resolve();
};
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command,) => {
if (command.indexOf('pip list') > 0) {
return Promise.reject();
} else if (command.indexOf('pip install') > 0) {
packagesInstalled = true;
return Promise.resolve('');
} else {
return Promise.resolve('');
}
});
let packageManager = createPackageManager(testContext);
await packageManager.installDependencies();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Succeeded);
should.equal(packagesInstalled, true);
});
it('installDependencies Should fail if install packages fails', async function (): Promise<void> {
let testContext = createContext();
let packagesInstalled = false;
let installedPackages = `[
{"name":"pymssql","version":"2.1.4"}
]`;
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
operationInfo.operation(testContext.op);
});
testContext.jupyterInstallation.installPythonPackage = () => {
return Promise.resolve();
};
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((command) => {
if (command.indexOf('pip list') > 0) {
return Promise.resolve(installedPackages);
} else if (command.indexOf('pip install') > 0) {
return Promise.reject();
} else {
return Promise.resolve('');
}
});
let packageManager = createPackageManager(testContext);
await should(packageManager.installDependencies()).rejected();
should.equal(testContext.getOpStatus(), azdata.TaskStatus.Failed);
should.equal(packagesInstalled, false);
});
function createPackageManager(testContext: TestContext): PackageManager {
testContext.config.setup(x => x.requiredPythonPackages).returns( () => [
{ name: 'pymssql', version: '2.1.4' },
{ name: 'sqlmlutils', version: '' }
]);
let packageManager = new PackageManager(
testContext.nbExtensionApis,
testContext.outputChannel,
'',
testContext.apiWrapper.object,
testContext.queryRunner.object,
testContext.processService.object,
testContext.config.object);
packageManager.init();
return packageManager;
}
});

View File

@@ -0,0 +1,373 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
import * as constants from '../../common/constants';
import { SqlPythonPackageManageProvider } from '../../packageManagement/sqlPackageManageProvider';
import { createContext, TestContext } from './utils';
import * as nbExtensionApis from '../../typings/notebookServices';
describe('SQL Package Manager', () => {
it('Should create SQL package manager successfully', async function (): Promise<void> {
let testContext = createContext();
should.doesNotThrow(() => createProvider(testContext));
});
it('Should return provider Id and target correctly', async function (): Promise<void> {
let testContext = createContext();
let provider = createProvider(testContext);
should.deepEqual(SqlPythonPackageManageProvider.ProviderId, provider.providerId);
should.deepEqual({ location: 'SQL', packageType: 'Python' }, provider.packageTarget);
});
it('listPackages Should return packages sorted by name', async function (): Promise<void> {
let testContext = createContext();
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'b-name',
'version': '1.1.1'
},
{
'name': 'a-name',
'version': '1.1.2'
}
];
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.queryRunner.setup(x => x.getPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
let provider = createProvider(testContext);
let actual = await provider.listPackages();
let expected = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
should.deepEqual(actual, expected);
});
it('listPackages Should return empty packages if undefined packages returned', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
let packages: nbExtensionApis.IPackageDetails[];
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.queryRunner.setup(x => x.getPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve(packages));
let provider = createProvider(testContext);
let actual = await provider.listPackages();
let expected: nbExtensionApis.IPackageDetails[] = [];
should.deepEqual(actual, expected);
});
it('listPackages Should return empty packages if empty packages returned', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.queryRunner.setup(x => x.getPythonPackages(TypeMoq.It.isAny())).returns(() => Promise.resolve([]));
let provider = createProvider(testContext);
let actual = await provider.listPackages();
let expected: nbExtensionApis.IPackageDetails[] = [];
should.deepEqual(actual, expected);
});
it('installPackages Should install given packages successfully', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('install') > 0) &&
scripts.find(x => x.indexOf('port=1433') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('package="a-name"') > 0) &&
scripts.find(x => x.indexOf('version="1.1.2"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve();
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false);
should.deepEqual(packagesUpdated, true);
});
it('uninstallPackages Should uninstall given packages successfully', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('uninstall') > 0) &&
scripts.find(x => x.indexOf('port=1433') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('package_name="a-name"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve();
});
let provider = createProvider(testContext);
await provider.uninstallPackages(packages);
should.deepEqual(packagesUpdated, true);
});
it('installPackages Should include port name in the script', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
{
'name': 'a-name',
'version': '1.1.2'
},
{
'name': 'b-name',
'version': '1.1.1'
}
];
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName,3433';
connection.databaseName = 'databaseName';
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns((path, scripts: string[]) => {
if (path && scripts.find(x => x.indexOf('install') > 0) &&
scripts.find(x => x.indexOf('port=3433') > 0) &&
scripts.find(x => x.indexOf('server="serverName"') > 0) &&
scripts.find(x => x.indexOf('database="databaseName"') > 0) &&
scripts.find(x => x.indexOf('package="a-name"') > 0) &&
scripts.find(x => x.indexOf('version="1.1.2"') > 0) &&
scripts.find(x => x.indexOf('pwd="password"') > 0)) {
packagesUpdated = true;
}
return Promise.resolve();
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false);
should.deepEqual(packagesUpdated, true);
});
it('installPackages Should not install any packages give empty list', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
];
let connection = new azdata.connection.ConnectionProfile();
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
packagesUpdated = true;
return Promise.resolve();
});
let provider = createProvider(testContext);
await provider.installPackages(packages, false);
should.deepEqual(packagesUpdated, false);
});
it('uninstallPackages Should not uninstall any packages give empty list', async function (): Promise<void> {
let testContext = createContext();
let packagesUpdated = false;
let packages: nbExtensionApis.IPackageDetails[] = [
];
let connection = new azdata.connection.ConnectionProfile();
let credentials = { ['azdata.ConnectionOptionSpecialType.password']: 'password' };
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
packagesUpdated = true;
return Promise.resolve();
});
let provider = createProvider(testContext);
await provider.uninstallPackages(packages);
should.deepEqual(packagesUpdated, false);
});
it('canUseProvider Should return false for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('canUseProvider Should return false if connection does not have python installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.queryRunner.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(false));
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, false);
});
it('canUseProvider Should return true if connection has python installed', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
testContext.queryRunner.setup(x => x.isPythonInstalled(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
let provider = createProvider(testContext);
let actual = await provider.canUseProvider();
should.deepEqual(actual, true);
});
it('getPackageOverview Should not return undefined if python package provider not found', async function (): Promise<void> {
let testContext = createContext();
let provider = createProvider(testContext);
let actual = await provider.getPackageOverview('package name');
should.notEqual(actual, undefined);
});
it('getPackageOverview Should return package info using python packages provider', async function (): Promise<void> {
let testContext = createContext();
let packagePreview = {
'name': 'a-name',
'versions': ['1.1.2'],
'summary': ''
};
let pythonPackageManager: nbExtensionApis.IPackageManageProvider = {
providerId: 'localhost_Pip',
packageTarget: { location: '', packageType: '' },
listPackages: () => { return Promise.resolve([]); },
installPackages: () => { return Promise.resolve(); },
uninstallPackages: () => { return Promise.resolve(); },
canUseProvider: () => { return Promise.resolve(true); },
getLocationTitle: () => { return Promise.resolve(''); },
getPackageOverview: () => { return Promise.resolve(packagePreview); }
};
testContext.nbExtensionApis.registerPackageManager(pythonPackageManager.providerId, pythonPackageManager);
let provider = createProvider(testContext);
let actual = await provider.getPackageOverview('package name');
should.deepEqual(actual, packagePreview);
});
it('getLocationTitle Should default string for no connection', async function (): Promise<void> {
let testContext = createContext();
let connection: azdata.connection.ConnectionProfile;
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.getLocationTitle();
should.deepEqual(actual, constants.packageManagerNoConnection);
});
it('getLocationTitle Should return connection title string for valid connection', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
connection.databaseName = 'databaseName';
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.getLocationTitle();
should.deepEqual(actual, `${connection.serverName} ${connection.databaseName}`);
});
it('getLocationTitle Should return server name as connection title if there is not database name', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
connection.serverName = 'serverName';
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
let provider = createProvider(testContext);
let actual = await provider.getLocationTitle();
should.deepEqual(actual, `${connection.serverName} `);
});
function createProvider(testContext: TestContext): SqlPythonPackageManageProvider {
return new SqlPythonPackageManageProvider(
testContext.nbExtensionApis,
testContext.outputChannel,
'',
testContext.apiWrapper.object,
testContext.queryRunner.object,
testContext.processService.object);
}
});

View File

@@ -0,0 +1,83 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as nbExtensionApis from '../../typings/notebookServices';
import * as TypeMoq from 'typemoq';
import { ApiWrapper } from '../../common/apiWrapper';
import { QueryRunner } from '../../common/queryRunner';
import { ProcessService } from '../../common/processService';
import { Config } from '../../common/config';
export interface TestContext {
jupyterInstallation: nbExtensionApis.IJupyterServerInstallation;
jupyterController: nbExtensionApis.IJupyterController;
nbExtensionApis: nbExtensionApis.IExtensionApi;
outputChannel: vscode.OutputChannel;
processService: TypeMoq.IMock<ProcessService>;
apiWrapper: TypeMoq.IMock<ApiWrapper>;
queryRunner: TypeMoq.IMock<QueryRunner>;
config: TypeMoq.IMock<Config>;
op: azdata.BackgroundOperation;
getOpStatus: () => azdata.TaskStatus;
}
export function createContext(): TestContext {
let opStatus: azdata.TaskStatus;
let packages = new Map<string, nbExtensionApis.IPackageManageProvider>();
let jupyterInstallation: nbExtensionApis.IJupyterServerInstallation = {
installCondaPackages: () => { return Promise.resolve(); },
getInstalledPipPackages: () => { return Promise.resolve([]); },
installPipPackages: () => { return Promise.resolve(); },
uninstallPipPackages: () => { return Promise.resolve(); },
uninstallCondaPackages: () => { return Promise.resolve(); },
executeBufferedCommand: () => { return Promise.resolve(''); },
executeStreamedCommand: () => { return Promise.resolve(); },
pythonExecutable: '',
pythonInstallationPath: '',
installPythonPackage: () => { return Promise.resolve(); }
};
let jupyterController = {
jupyterInstallation: jupyterInstallation
};
return {
jupyterInstallation: jupyterInstallation,
jupyterController: jupyterController,
nbExtensionApis: {
getJupyterController: () => { return jupyterController; },
registerPackageManager: (providerId: string, packageManagerProvider: nbExtensionApis.IPackageManageProvider) => {
packages.set(providerId, packageManagerProvider);
},
getPackageManagers: () => { return packages; },
},
outputChannel: {
name: '',
append: () => { },
appendLine: () => { },
clear: () => { },
show: () => { },
hide: () => { },
dispose: () => { }
},
processService: TypeMoq.Mock.ofType(ProcessService),
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryRunner: TypeMoq.Mock.ofType(QueryRunner),
config: TypeMoq.Mock.ofType(Config),
op: {
updateStatus: (status: azdata.TaskStatus) => {
opStatus = status;
},
id: '',
onCanceled: new vscode.EventEmitter<void>().event,
},
getOpStatus: () => { return opStatus; }
};
}

View File

@@ -0,0 +1,305 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
import * as azdata from 'azdata';
import { ApiWrapper } from '../common/apiWrapper';
import * as TypeMoq from 'typemoq';
import * as should from 'should';
import { QueryRunner } from '../common/queryRunner';
import { IPackageDetails } from '../typings/notebookServices';
interface TestContext {
apiWrapper: TypeMoq.IMock<ApiWrapper>;
queryProvider: azdata.QueryProvider;
}
function createContext(): TestContext {
return {
apiWrapper: TypeMoq.Mock.ofType(ApiWrapper),
queryProvider: {
providerId: '',
cancelQuery: () => {return Promise.reject();},
runQuery: () => {return Promise.reject();},
runQueryStatement: () => {return Promise.reject();},
runQueryString: () => {return Promise.reject();},
runQueryAndReturn: () => { return Promise.reject(); },
parseSyntax: () => {return Promise.reject();},
getQueryRows: () => {return Promise.reject();},
disposeQuery: () => {return Promise.reject();},
saveResults: () => {return Promise.reject();},
setQueryExecutionOptions: () => {return Promise.reject();},
registerOnQueryComplete: () => {return Promise.reject();},
registerOnBatchStart: () => {return Promise.reject();},
registerOnBatchComplete: () => {return Promise.reject();},
registerOnResultSetAvailable: () => {return Promise.reject();},
registerOnResultSetUpdated: () => {return Promise.reject();},
registerOnMessage: () => {return Promise.reject();},
commitEdit: () => {return Promise.reject();},
createRow: () => {return Promise.reject();},
deleteRow: () => {return Promise.reject();},
disposeEdit: () => {return Promise.reject();},
initializeEdit: () => {return Promise.reject();},
revertCell: () => {return Promise.reject();},
revertRow: () => {return Promise.reject();},
updateCell: () => {return Promise.reject();},
getEditRows: () => {return Promise.reject();},
registerOnEditSessionReady: () => {return Promise.reject();},
}
};
}
describe('Query Runner', () => {
it('getPythonPackages Should return empty list if not provider found', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
let queryProvider: azdata.QueryProvider;
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => queryProvider);
let actual = await queryRunner.getPythonPackages(connection);
should.deepEqual(actual, []);
});
it('getPythonPackages Should return empty list if not provider throws', async function (): Promise<void> {
let testContext = createContext();
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.reject(); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.getPythonPackages(connection);
should.deepEqual(actual, []);
});
it('getPythonPackages Should return list if provider runs the query successfully', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: 'p1',
isNull: false,
invariantCultureDisplayValue: ''
}, {
displayValue: '1.1.1',
isNull: false,
invariantCultureDisplayValue: ''
}],
[{
displayValue: 'p2',
isNull: false,
invariantCultureDisplayValue: ''
}, {
displayValue: '1.1.2',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = [
{
'name': 'p1',
'version': '1.1.1'
},
{
'name': 'p2',
'version': '1.1.2'
}
];
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.getPythonPackages(connection);
should.deepEqual(actual, expected);
});
it('getPythonPackages Should return empty list if provider return no rows', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
];
let expected: IPackageDetails[] = [];
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.getPythonPackages(connection);
should.deepEqual(actual, expected);
});
it('updateExternalScriptConfig Should update config successfully', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
];
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
await should(queryRunner.updateExternalScriptConfig(connection, true)).resolved();
});
it('isPythonInstalled Should return true is provider returns valid result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = true;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isPythonInstalled(connection);
should.deepEqual(actual, expected);
});
it('isPythonInstalled Should return true is provider returns 0 as result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '0',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isPythonInstalled(connection);
should.deepEqual(actual, expected);
});
it('isPythonInstalled Should return false is provider returns no result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isPythonInstalled(connection);
should.deepEqual(actual, expected);
});
it('isMachineLearningServiceEnabled Should return true is provider returns valid result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '1',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = true;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
should.deepEqual(actual, expected);
});
it('isMachineLearningServiceEnabled Should return true is provider returns 0 as result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [
[{
displayValue: '0',
isNull: false,
invariantCultureDisplayValue: ''
}]
];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
should.deepEqual(actual, expected);
});
it('isMachineLearningServiceEnabled Should return false is provider returns no result', async function (): Promise<void> {
let testContext = createContext();
let rows: azdata.DbCellValue[][] = [];
let expected = false;
let result : azdata.SimpleExecuteResult = {
rowCount: 2,
columnInfo: [],
rows: rows,
};
let connection = new azdata.connection.ConnectionProfile();
let queryRunner = new QueryRunner(testContext.apiWrapper.object);
testContext.queryProvider.runQueryAndReturn = () => { return Promise.resolve(result); };
testContext.apiWrapper.setup(x => x.getProvider<azdata.QueryProvider>(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => testContext.queryProvider);
let actual = await queryRunner.isMachineLearningServiceEnabled(connection);
should.deepEqual(actual, expected);
});
});

View File

@@ -0,0 +1,23 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
'use strict';
const _typeof = {
undefined: 'undefined'
};
/**
* @returns whether the provided parameter is undefined or null.
*/
export function isUndefinedOrNull(obj: any): boolean {
return isUndefined(obj) || obj === null;
}
/**
* @returns whether the provided parameter is undefined.
*/
export function isUndefined(obj: any): boolean {
return typeof (obj) === _typeof.undefined;
}

View File

@@ -0,0 +1,63 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as vscode from 'vscode';
import * as azdata from 'azdata';
/**
* The API provided by this extension.
*
* @export
*/
export interface IExtensionApi {
getJupyterController(): IJupyterController;
registerPackageManager(providerId: string, packageManagerProvider: IPackageManageProvider): void
getPackageManagers(): Map<string, IPackageManageProvider>
}
export interface IJupyterController {
jupyterInstallation: IJupyterServerInstallation;
}
export interface IJupyterServerInstallation {
installPipPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise<void>;
uninstallPipPackages(packages: IPackageDetails[]): Promise<void>;
installCondaPackages(packages: IPackageDetails[], useMinVersion: boolean): Promise<void>;
uninstallCondaPackages(packages: IPackageDetails[]): Promise<void>;
getInstalledPipPackages(): Promise<IPackageDetails[]>;
pythonExecutable: string;
pythonInstallationPath: string;
executeBufferedCommand(command: string): Promise<string>;
executeStreamedCommand(command: string): Promise<void>;
installPythonPackage(backgroundOperation: azdata.BackgroundOperation, usingExistingPython: boolean, pythonInstallationPath: string, outputChannel: vscode.OutputChannel): Promise<void>;
}
export interface IPackageDetails {
name: string;
version: string;
}
export interface IPackageTarget {
location: string;
packageType: string;
}
export interface IPackageOverview {
name: string;
versions: string[];
summary: string;
}
export interface IPackageManageProvider {
providerId: string;
packageTarget: IPackageTarget;
listPackages(): Promise<IPackageDetails[]>
installPackages(package: IPackageDetails[], useMinVersion: boolean): Promise<void>;
uninstallPackages(package: IPackageDetails[]): Promise<void>;
canUseProvider(): Promise<boolean>;
getLocationTitle(): Promise<string>;
getPackageOverview(packageName: string): Promise<IPackageOverview>
}

View File

@@ -0,0 +1,9 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
/// <reference path='../../../../src/vs/vscode.d.ts'/>
/// <reference path='../../../../src/sql/azdata.d.ts'/>
/// <reference path='../../../../src/sql/azdata.proposed.d.ts'/>
/// <reference types='@types/node'/>

View File

@@ -0,0 +1,24 @@
{
"extends": "../shared.tsconfig.json",
"compileOnSave": true,
"compilerOptions": {
"module": "commonjs",
"target": "es6",
"outDir": "./out",
"lib": [
"es6",
"es2015.promise"
],
"sourceMap": true,
"emitDecoratorMetadata": true,
"experimentalDecorators": true,
"moduleResolution": "node",
"declaration": false
},
"include": [
"src/**/*"
],
"exclude": [
"node_modules"
]
}

File diff suppressed because it is too large Load Diff

View File

@@ -52,6 +52,8 @@ export interface IJupyterServerInstallation {
installPipPackages(packages: PythonPkgDetails[], useMinVersion: boolean): Promise<void>;
uninstallPipPackages(packages: PythonPkgDetails[]): Promise<void>;
pythonExecutable: string;
pythonInstallationPath: string;
installPythonPackage(backgroundOperation: azdata.BackgroundOperation, usingExistingPython: boolean, pythonInstallationPath: string, outputChannel: OutputChannel): Promise<void>;
}
export class JupyterServerInstallation implements IJupyterServerInstallation {
public apiWrapper: ApiWrapper;
@@ -128,7 +130,7 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, msgInstallPkgProgress);
try {
await this.installPythonPackage(backgroundOperation);
await this.installPythonPackage(backgroundOperation, this._usingExistingPython, this._pythonInstallationPath, this.outputChannel);
if (this._usingExistingPython) {
await this.upgradePythonPackages(false, forceInstall);
@@ -146,8 +148,8 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
}
}
private installPythonPackage(backgroundOperation: azdata.BackgroundOperation): Promise<void> {
if (this._usingExistingPython) {
public installPythonPackage(backgroundOperation: azdata.BackgroundOperation, usingExistingPython: boolean, pythonInstallationPath: string, outputChannel: OutputChannel): Promise<void> {
if (usingExistingPython) {
return Promise.resolve();
}
@@ -174,7 +176,7 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
}
return new Promise((resolve, reject) => {
let installPath = this._pythonInstallationPath;
let installPath = pythonInstallationPath;
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, msgDownloadPython(platformId, pythonDownloadUrl));
fs.mkdirs(installPath, (err) => {
if (err) {
@@ -198,7 +200,7 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
let totalBytes = parseInt(response.headers['content-length']);
totalMegaBytes = totalBytes / (1024 * 1024);
this.outputChannel.appendLine(`${msgPythonDownloadPending} (0 / ${totalMegaBytes.toFixed(2)} MB)`);
outputChannel.appendLine(`${msgPythonDownloadPending} (0 / ${totalMegaBytes.toFixed(2)} MB)`);
})
.on('data', (data) => {
receivedBytes += data.length;
@@ -206,7 +208,7 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
let receivedMegaBytes = receivedBytes / (1024 * 1024);
let percentage = receivedMegaBytes / totalMegaBytes;
if (percentage >= printThreshold) {
this.outputChannel.appendLine(`${msgPythonDownloadPending} (${receivedMegaBytes.toFixed(2)} / ${totalMegaBytes.toFixed(2)} MB)`);
outputChannel.appendLine(`${msgPythonDownloadPending} (${receivedMegaBytes.toFixed(2)} / ${totalMegaBytes.toFixed(2)} MB)`);
printThreshold += 0.1;
}
}
@@ -216,7 +218,7 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
downloadRequest.pipe(fs.createWriteStream(pythonPackagePathLocal))
.on('close', async () => {
//unpack python zip/tar file
this.outputChannel.appendLine(msgPythonUnpackPending);
outputChannel.appendLine(msgPythonUnpackPending);
let pythonSourcePath = path.join(installPath, constants.pythonBundleVersion);
if (await utils.exists(pythonSourcePath)) {
try {
@@ -235,7 +237,7 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
}
});
this.outputChannel.appendLine(msgPythonDownloadComplete);
outputChannel.appendLine(msgPythonDownloadComplete);
backgroundOperation.updateStatus(azdata.TaskStatus.InProgress, msgPythonDownloadComplete);
resolve();
}).catch(err => {
@@ -657,6 +659,13 @@ export class JupyterServerInstallation implements IJupyterServerInstallation {
process.platform === constants.winPlatform ? 'Scripts\\conda.exe' : 'bin/conda');
}
/**
* Returns Python installation path
*/
public get pythonInstallationPath(): string {
return this._pythonInstallationPath;
}
public get usingConda(): boolean {
return this._usingConda;
}

View File

@@ -5,6 +5,8 @@
'use strict';
import * as vscode from 'vscode';
import * as azdata from 'azdata';
import * as should from 'should';
import 'mocha';
import * as TypeMoq from 'typemoq';
@@ -200,7 +202,9 @@ describe('Manage Package Providers', () => {
executeStreamedCommand: (command: string) => { return Promise.resolve(); },
getCondaExePath: () => { return ''; },
pythonExecutable: '',
usingConda: false
pythonInstallationPath: '',
usingConda: false,
installPythonPackage: (backgroundOperation: azdata.BackgroundOperation, usingExistingPython: boolean, pythonInstallationPath: string, outputChannel: vscode.OutputChannel) => {return Promise.resolve(); }
},
piPyClient: {
fetchPypiPackage: (packageName) => { return Promise.resolve(); }

View File

@@ -78,6 +78,11 @@ echo *** starting resource deployment tests ***
echo ******************************************
call "%INTEGRATION_TEST_ELECTRON_PATH%" --nogpu --extensionDevelopmentPath=%~dp0\..\extensions\resource-deployment --extensionTestsPath=%~dp0\..\extensions\resource-deployment\out\test --user-data-dir=%VSCODEUSERDATADIR% --extensions-dir=%VSCODEEXTENSIONSDIR% --remote-debugging-port=9222
echo *******************************
echo *** starting machine-learning-services tests ***
echo *******************************
call .\scripts\code.bat --extensionDevelopmentPath=%~dp0\..\extensions\machine-learning-services --extensionTestsPath=%~dp0\..\extensions\machine-learning-services\out\test --user-data-dir=%VSCODEUSERDATADIR% --extensions-dir=%VSCODEEXTENSIONSDIR% --remote-debugging-port=9222
if %errorlevel% neq 0 exit /b %errorlevel%
rmdir /s /q %VSCODEUSERDATADIR%

View File

@@ -79,5 +79,10 @@ echo *** starting resource deployment tests ***
echo ******************************************
"$INTEGRATION_TEST_ELECTRON_PATH" $LINUX_NO_SANDBOX --nogpu --extensionDevelopmentPath=$ROOT/extensions/resource-deployment --extensionTestsPath=$ROOT/extensions/resource-deployment/out/test --user-data-dir=$VSCODEUSERDATADIR --extensions-dir=$VSCODEEXTDIR
echo ******************************************
echo *** starting machine-learning-services tests ***
echo ******************************************
"$INTEGRATION_TEST_ELECTRON_PATH" $LINUX_NO_SANDBOX --nogpu --extensionDevelopmentPath=$ROOT/extensions/machine-learning-services --extensionTestsPath=$ROOT/extensions/machine-learning-services/out/test --user-data-dir=$VSCODEUSERDATADIR --extensions-dir=$VSCODEEXTDIR
rm -r $VSCODEUSERDATADIR
rm -r $VSCODEEXTDIR