diff --git a/extensions/machine-learning/src/common/constants.ts b/extensions/machine-learning/src/common/constants.ts index 71f6d153d4..64531f9231 100644 --- a/extensions/machine-learning/src/common/constants.ts +++ b/extensions/machine-learning/src/common/constants.ts @@ -13,6 +13,7 @@ export const managePackagesCommand = 'jupyter.cmd.managePackages'; export const pythonLanguageName = 'Python'; export const rLanguageName = 'R'; export const rLPackagedFolderName = 'r_packages'; +export const supportedODBCDriver = 'ODBC Driver 17 for SQL Server'; export const mlEnableMlsCommand = 'ml.command.enableMls'; export const mlDisableMlsCommand = 'ml.command.disableMls'; @@ -48,6 +49,7 @@ export const adsPythonBundleVersion = '0.0.1'; export const msgYes = localize('msgYes', "Yes"); export const msgNo = localize('msgNo', "No"); export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed."); +export const verifyOdbcDriverError = localize('mls.verifyOdbcDriverError.error', "'{0}' is required for package management. Please make sure it is installed and set up correctly.", supportedODBCDriver); export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); } export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); } export function cannotFindR(path: string): string { return localize('mls.cannotFindR.error', "Cannot find R executable '{0}'. Please make sure R is installed and configured correctly", path); } diff --git a/extensions/machine-learning/src/common/utils.ts b/extensions/machine-learning/src/common/utils.ts index 76bfe244c6..79b5132c79 100644 --- a/extensions/machine-learning/src/common/utils.ts +++ b/extensions/machine-learning/src/common/utils.ts @@ -283,3 +283,34 @@ export function getPythonExeName(): string { export function getUserHome(): string | undefined { return process.env.HOME || process.env.USERPROFILE; } + +export function getKeyValueString(key: string, value: string, separator: string = '='): string { + return `${key}${separator}${value}`; +} + +export function getServerPort(connection: azdata.connection.ConnectionProfile): string { + if (!connection) { + return ''; + } + let index = connection.serverName.indexOf(','); + if (index > 0) { + return connection.serverName.substring(index + 1); + } else { + return '1433'; + } +} + +export function getServerName(connection: azdata.connection.ConnectionProfile): string { + if (!connection) { + return ''; + } + let index = connection.serverName.indexOf(','); + if (index > 0) { + return connection.serverName.substring(0, index); + } else { + return connection.serverName; + } +} + + + diff --git a/extensions/machine-learning/src/packageManagement/packageManager.ts b/extensions/machine-learning/src/packageManagement/packageManager.ts index 42bdfcd48b..60df14debb 100644 --- a/extensions/machine-learning/src/packageManagement/packageManager.ts +++ b/extensions/machine-learning/src/packageManagement/packageManager.ts @@ -120,6 +120,8 @@ export class PackageManager { await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [ this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages), this.installRequiredRPackages()], true); + + await this.verifyOdbcInstalled(); } private async installRequiredRPackages(): Promise { @@ -186,6 +188,45 @@ export class PackageManager { } } + private async verifyOdbcInstalled(): Promise { + let connection = await this.getCurrentConnection(); + if (connection) { + let credentials = await this._apiWrapper.getCredentials(connection.connectionId); + const separator = '='; + let connectionParts: string[] = []; + if (connection) { + connectionParts.push(utils.getKeyValueString('DRIVER', `{${constants.supportedODBCDriver}}`, separator)); + + if (connection.userName) { + connectionParts.push(utils.getKeyValueString('UID', connection.userName, separator)); + connectionParts.push(utils.getKeyValueString('PWD', credentials[azdata.ConnectionOptionSpecialType.password], separator)); + } else { + connectionParts.push(utils.getKeyValueString('Trusted_Connection', 'yes', separator)); + + } + + connectionParts.push(utils.getKeyValueString('SERVER', connection.serverName, separator)); + } + + let scripts: string[] = [ + 'import pyodbc', + `connection = pyodbc.connect('${connectionParts.join(';')}')`, + 'cursor = connection.cursor()', + 'cursor.execute("SELECT @@version;")' + ]; + let pythonExecutable = await this._config.getPythonExecutable(true); + try { + await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel); + } catch (err) { + const result = await this._apiWrapper.showErrorMessage(constants.verifyOdbcDriverError, constants.learnMoreTitle); + if (result === constants.learnMoreTitle) { + await this._apiWrapper.openExternal(vscode.Uri.parse(constants.odbcDriverDocuments)); + } + throw err; + } + } + } + private async getInstalledPipPackages(): Promise { try { let pythonExecutable = await this.getPythonExecutable(); diff --git a/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts b/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts index 9a8c5d8bbe..af7aa1c62f 100644 --- a/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts +++ b/extensions/machine-learning/src/packageManagement/sqlPythonPackageManageProvider.ts @@ -13,6 +13,7 @@ import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProvide import { HttpClient } from '../common/httpClient'; import * as utils from '../common/utils'; import { PackageManagementService } from './packageManagementService'; +import * as constants from '../common/constants'; /** * Manage Package Provider for python packages inside SQL server databases @@ -62,26 +63,31 @@ export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise { let connection = await this.getCurrentConnection(); let credentials = await this._apiWrapper.getCredentials(connection.connectionId); + let connectionParts: string[] = []; if (connection) { - let port = '1433'; - let server = connection.serverName; - let database = databaseName ? `, database="${databaseName}"` : ''; - const auth = connection.userName ? `, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"` : ''; - let index = connection.serverName.indexOf(','); - if (index > 0) { - port = connection.serverName.substring(index + 1); - server = connection.serverName.substring(0, index); + connectionParts.push(utils.getKeyValueString('driver', `"${constants.supportedODBCDriver}"`)); + + let port = utils.getServerPort(connection); + let server = utils.getServerName(connection); + if (databaseName) { + connectionParts.push(utils.getKeyValueString('database', `"${databaseName}"`)); + } + if (connection.userName) { + connectionParts.push(utils.getKeyValueString('uid', `"${connection.userName}"`)); + connectionParts.push(utils.getKeyValueString('pwd', `"${credentials[azdata.ConnectionOptionSpecialType.password]}"`)); } - let pythonConnectionParts = `server="${server}", port=${port}${auth}${database})`; + connectionParts.push(utils.getKeyValueString('server', `"${server}"`)); + connectionParts.push(utils.getKeyValueString('port', port)); + let pythonCommandScript = scriptMode === ScriptMode.Install ? `pkgmanager.install(package="${packageDetails.name}", version="${packageDetails.version}")` : `pkgmanager.uninstall(package_name="${packageDetails.name}")`; let scripts: string[] = [ 'import sqlmlutils', - `connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 17 for SQL Server", ${pythonConnectionParts}`, + `connection = sqlmlutils.ConnectionInfo(${connectionParts.join(',')})`, 'pkgmanager = sqlmlutils.SQLPackageManager(connection)', pythonCommandScript ]; diff --git a/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts b/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts index 6e2f163632..8ac9d896d9 100644 --- a/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts +++ b/extensions/machine-learning/src/packageManagement/sqlRPackageManageProvider.ts @@ -14,7 +14,7 @@ import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProvide import { HttpClient } from '../common/httpClient'; import * as constants from '../common/constants'; import { PackageManagementService } from './packageManagementService'; - +import * as utils from '../common/utils'; /** @@ -66,18 +66,26 @@ export class SqlRPackageManageProvider extends SqlPackageManageProviderBase impl protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise { let connection = await this.getCurrentConnection(); let credentials = await this._apiWrapper.getCredentials(connection.connectionId); + let connectionParts: string[] = []; if (connection) { + connectionParts.push(utils.getKeyValueString('driver', constants.supportedODBCDriver)); let server = connection.serverName.replace('\\', '\\\\'); - let database = databaseName ? `, database="${databaseName}"` : ''; - const auth = connection.userName ? `, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"` : ''; - let connectionParts = `server="${server}"${auth}${database}`; + if (databaseName) { + connectionParts.push(utils.getKeyValueString('database', `"${databaseName}"`)); + } + if (connection.userName) { + connectionParts.push(utils.getKeyValueString('uid', `"${connection.userName}"`)); + connectionParts.push(utils.getKeyValueString('pwd', `"${credentials[azdata.ConnectionOptionSpecialType.password]}"`)); + } + connectionParts.push(utils.getKeyValueString('server', `"${server}"`)); + let rCommandScript = scriptMode === ScriptMode.Install ? 'sql_install.packages' : 'sql_remove.packages'; let scripts: string[] = [ 'formals(quit)$save <- formals(q)$save <- "no"', 'library(sqlmlutils)', - `connection <- connectionInfo(driver= "ODBC Driver 17 for SQL Server", ${connectionParts})`, + `connection <- connectionInfo(${connectionParts.join(', ')})`, `r = getOption("repos")`, `r["CRAN"] = "${this._config.rPackagesRepository}"`, `options(repos = r)`, diff --git a/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts b/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts index 2cfd1dedaf..7e7edbc02f 100644 --- a/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts +++ b/extensions/machine-learning/src/test/packageManagement/packageManager.test.ts @@ -10,6 +10,7 @@ import 'mocha'; import * as TypeMoq from 'typemoq'; import { PackageManager } from '../../packageManagement/packageManager'; import { createContext, TestContext } from './utils'; +import * as constants from '../../common/constants'; describe('Package Manager', () => { it('Should initialize SQL package manager successfully', async function (): Promise { @@ -114,6 +115,56 @@ describe('Package Manager', () => { should.equal(packagesInstalled, false); }); + it('installDependencies Should fail if odbc not installed', async function (): Promise { + let testContext = createContext(); + let installedPackages = `[ + {"name":"pymssql","version":"2.1.4"}, + {"name":"sqlmlutils","version":"1.1.1"} + ]`; + let connection = new azdata.connection.ConnectionProfile(); + let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' }; + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); }); + testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { + operationInfo.operation(testContext.op); + }); + testContext.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve('')); + testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return Promise.resolve(installedPackages); + }); + + testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject('error')); + + let packageManager = createPackageManager(testContext); + await should(packageManager.installDependencies()).be.rejected(); + }); + + it('installDependencies should open link for odbc document if user selects the link', async function (): Promise { + let testContext = createContext(); + let connection = new azdata.connection.ConnectionProfile(); + let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' }; + testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); }); + testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); }); + let installedPackages = `[ + {"name":"pymssql","version":"2.1.4"}, + {"name":"sqlmlutils","version":"1.1.1"} + ]`; + testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => { + operationInfo.operation(testContext.op); + }); + testContext.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(constants.learnMoreTitle)); + testContext.apiWrapper.setup(x => x.openExternal(TypeMoq.It.isAny())).returns(() => Promise.resolve(true)); + testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => { + return Promise.resolve(installedPackages); + }); + + testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject('error')); + + let packageManager = createPackageManager(testContext); + await should(packageManager.installDependencies()).be.rejected(); + testContext.apiWrapper.verify(x => x.openExternal(TypeMoq.It.isAny()), TypeMoq.Times.atMostOnce()); + }); + it('installDependencies Should install packages that are not already installed', async function (): Promise { let testContext = createContext(); let packagesInstalled = false;