mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
ML - Verifying ODBC Driver (#11587)
* Verifying ODBC is installed before opening package manager dialog
This commit is contained in:
@@ -13,6 +13,7 @@ export const managePackagesCommand = 'jupyter.cmd.managePackages';
|
||||
export const pythonLanguageName = 'Python';
|
||||
export const rLanguageName = 'R';
|
||||
export const rLPackagedFolderName = 'r_packages';
|
||||
export const supportedODBCDriver = 'ODBC Driver 17 for SQL Server';
|
||||
|
||||
export const mlEnableMlsCommand = 'ml.command.enableMls';
|
||||
export const mlDisableMlsCommand = 'ml.command.disableMls';
|
||||
@@ -48,6 +49,7 @@ export const adsPythonBundleVersion = '0.0.1';
|
||||
export const msgYes = localize('msgYes', "Yes");
|
||||
export const msgNo = localize('msgNo', "No");
|
||||
export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed.");
|
||||
export const verifyOdbcDriverError = localize('mls.verifyOdbcDriverError.error', "'{0}' is required for package management. Please make sure it is installed and set up correctly.", supportedODBCDriver);
|
||||
export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); }
|
||||
export function cannotFindPython(path: string): string { return localize('mls.cannotFindPython.error', "Cannot find Python executable '{0}'. Please make sure Python is installed and configured correctly", path); }
|
||||
export function cannotFindR(path: string): string { return localize('mls.cannotFindR.error', "Cannot find R executable '{0}'. Please make sure R is installed and configured correctly", path); }
|
||||
|
||||
@@ -283,3 +283,34 @@ export function getPythonExeName(): string {
|
||||
export function getUserHome(): string | undefined {
|
||||
return process.env.HOME || process.env.USERPROFILE;
|
||||
}
|
||||
|
||||
export function getKeyValueString(key: string, value: string, separator: string = '='): string {
|
||||
return `${key}${separator}${value}`;
|
||||
}
|
||||
|
||||
export function getServerPort(connection: azdata.connection.ConnectionProfile): string {
|
||||
if (!connection) {
|
||||
return '';
|
||||
}
|
||||
let index = connection.serverName.indexOf(',');
|
||||
if (index > 0) {
|
||||
return connection.serverName.substring(index + 1);
|
||||
} else {
|
||||
return '1433';
|
||||
}
|
||||
}
|
||||
|
||||
export function getServerName(connection: azdata.connection.ConnectionProfile): string {
|
||||
if (!connection) {
|
||||
return '';
|
||||
}
|
||||
let index = connection.serverName.indexOf(',');
|
||||
if (index > 0) {
|
||||
return connection.serverName.substring(0, index);
|
||||
} else {
|
||||
return connection.serverName;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -120,6 +120,8 @@ export class PackageManager {
|
||||
await utils.executeTasks(this._apiWrapper, constants.installPackageMngDependenciesMsgTaskName, [
|
||||
this.installRequiredPythonPackages(this._config.requiredSqlPythonPackages),
|
||||
this.installRequiredRPackages()], true);
|
||||
|
||||
await this.verifyOdbcInstalled();
|
||||
}
|
||||
|
||||
private async installRequiredRPackages(): Promise<void> {
|
||||
@@ -186,6 +188,45 @@ export class PackageManager {
|
||||
}
|
||||
}
|
||||
|
||||
private async verifyOdbcInstalled(): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
const separator = '=';
|
||||
let connectionParts: string[] = [];
|
||||
if (connection) {
|
||||
connectionParts.push(utils.getKeyValueString('DRIVER', `{${constants.supportedODBCDriver}}`, separator));
|
||||
|
||||
if (connection.userName) {
|
||||
connectionParts.push(utils.getKeyValueString('UID', connection.userName, separator));
|
||||
connectionParts.push(utils.getKeyValueString('PWD', credentials[azdata.ConnectionOptionSpecialType.password], separator));
|
||||
} else {
|
||||
connectionParts.push(utils.getKeyValueString('Trusted_Connection', 'yes', separator));
|
||||
|
||||
}
|
||||
|
||||
connectionParts.push(utils.getKeyValueString('SERVER', connection.serverName, separator));
|
||||
}
|
||||
|
||||
let scripts: string[] = [
|
||||
'import pyodbc',
|
||||
`connection = pyodbc.connect('${connectionParts.join(';')}')`,
|
||||
'cursor = connection.cursor()',
|
||||
'cursor.execute("SELECT @@version;")'
|
||||
];
|
||||
let pythonExecutable = await this._config.getPythonExecutable(true);
|
||||
try {
|
||||
await this._processService.execScripts(pythonExecutable, scripts, [], this._outputChannel);
|
||||
} catch (err) {
|
||||
const result = await this._apiWrapper.showErrorMessage(constants.verifyOdbcDriverError, constants.learnMoreTitle);
|
||||
if (result === constants.learnMoreTitle) {
|
||||
await this._apiWrapper.openExternal(vscode.Uri.parse(constants.odbcDriverDocuments));
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
try {
|
||||
let pythonExecutable = await this.getPythonExecutable();
|
||||
|
||||
@@ -13,6 +13,7 @@ import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProvide
|
||||
import { HttpClient } from '../common/httpClient';
|
||||
import * as utils from '../common/utils';
|
||||
import { PackageManagementService } from './packageManagementService';
|
||||
import * as constants from '../common/constants';
|
||||
|
||||
/**
|
||||
* Manage Package Provider for python packages inside SQL server databases
|
||||
@@ -62,26 +63,31 @@ export class SqlPythonPackageManageProvider extends SqlPackageManageProviderBase
|
||||
protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
let connectionParts: string[] = [];
|
||||
|
||||
if (connection) {
|
||||
let port = '1433';
|
||||
let server = connection.serverName;
|
||||
let database = databaseName ? `, database="${databaseName}"` : '';
|
||||
const auth = connection.userName ? `, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"` : '';
|
||||
let index = connection.serverName.indexOf(',');
|
||||
if (index > 0) {
|
||||
port = connection.serverName.substring(index + 1);
|
||||
server = connection.serverName.substring(0, index);
|
||||
connectionParts.push(utils.getKeyValueString('driver', `"${constants.supportedODBCDriver}"`));
|
||||
|
||||
let port = utils.getServerPort(connection);
|
||||
let server = utils.getServerName(connection);
|
||||
if (databaseName) {
|
||||
connectionParts.push(utils.getKeyValueString('database', `"${databaseName}"`));
|
||||
}
|
||||
if (connection.userName) {
|
||||
connectionParts.push(utils.getKeyValueString('uid', `"${connection.userName}"`));
|
||||
connectionParts.push(utils.getKeyValueString('pwd', `"${credentials[azdata.ConnectionOptionSpecialType.password]}"`));
|
||||
}
|
||||
|
||||
let pythonConnectionParts = `server="${server}", port=${port}${auth}${database})`;
|
||||
connectionParts.push(utils.getKeyValueString('server', `"${server}"`));
|
||||
connectionParts.push(utils.getKeyValueString('port', port));
|
||||
|
||||
let pythonCommandScript = scriptMode === ScriptMode.Install ?
|
||||
`pkgmanager.install(package="${packageDetails.name}", version="${packageDetails.version}")` :
|
||||
`pkgmanager.uninstall(package_name="${packageDetails.name}")`;
|
||||
|
||||
let scripts: string[] = [
|
||||
'import sqlmlutils',
|
||||
`connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 17 for SQL Server", ${pythonConnectionParts}`,
|
||||
`connection = sqlmlutils.ConnectionInfo(${connectionParts.join(',')})`,
|
||||
'pkgmanager = sqlmlutils.SQLPackageManager(connection)',
|
||||
pythonCommandScript
|
||||
];
|
||||
|
||||
@@ -14,7 +14,7 @@ import { SqlPackageManageProviderBase, ScriptMode } from './packageManageProvide
|
||||
import { HttpClient } from '../common/httpClient';
|
||||
import * as constants from '../common/constants';
|
||||
import { PackageManagementService } from './packageManagementService';
|
||||
|
||||
import * as utils from '../common/utils';
|
||||
|
||||
|
||||
/**
|
||||
@@ -66,18 +66,26 @@ export class SqlRPackageManageProvider extends SqlPackageManageProviderBase impl
|
||||
protected async executeScripts(scriptMode: ScriptMode, packageDetails: nbExtensionApis.IPackageDetails, databaseName: string): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
let connectionParts: string[] = [];
|
||||
|
||||
if (connection) {
|
||||
connectionParts.push(utils.getKeyValueString('driver', constants.supportedODBCDriver));
|
||||
let server = connection.serverName.replace('\\', '\\\\');
|
||||
let database = databaseName ? `, database="${databaseName}"` : '';
|
||||
const auth = connection.userName ? `, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"` : '';
|
||||
let connectionParts = `server="${server}"${auth}${database}`;
|
||||
if (databaseName) {
|
||||
connectionParts.push(utils.getKeyValueString('database', `"${databaseName}"`));
|
||||
}
|
||||
if (connection.userName) {
|
||||
connectionParts.push(utils.getKeyValueString('uid', `"${connection.userName}"`));
|
||||
connectionParts.push(utils.getKeyValueString('pwd', `"${credentials[azdata.ConnectionOptionSpecialType.password]}"`));
|
||||
}
|
||||
connectionParts.push(utils.getKeyValueString('server', `"${server}"`));
|
||||
|
||||
let rCommandScript = scriptMode === ScriptMode.Install ? 'sql_install.packages' : 'sql_remove.packages';
|
||||
|
||||
let scripts: string[] = [
|
||||
'formals(quit)$save <- formals(q)$save <- "no"',
|
||||
'library(sqlmlutils)',
|
||||
`connection <- connectionInfo(driver= "ODBC Driver 17 for SQL Server", ${connectionParts})`,
|
||||
`connection <- connectionInfo(${connectionParts.join(', ')})`,
|
||||
`r = getOption("repos")`,
|
||||
`r["CRAN"] = "${this._config.rPackagesRepository}"`,
|
||||
`options(repos = r)`,
|
||||
|
||||
@@ -10,6 +10,7 @@ import 'mocha';
|
||||
import * as TypeMoq from 'typemoq';
|
||||
import { PackageManager } from '../../packageManagement/packageManager';
|
||||
import { createContext, TestContext } from './utils';
|
||||
import * as constants from '../../common/constants';
|
||||
|
||||
describe('Package Manager', () => {
|
||||
it('Should initialize SQL package manager successfully', async function (): Promise<void> {
|
||||
@@ -114,6 +115,56 @@ describe('Package Manager', () => {
|
||||
should.equal(packagesInstalled, false);
|
||||
});
|
||||
|
||||
it('installDependencies Should fail if odbc not installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"},
|
||||
{"name":"sqlmlutils","version":"1.1.1"}
|
||||
]`;
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(''));
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
return Promise.resolve(installedPackages);
|
||||
});
|
||||
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject('error'));
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await should(packageManager.installDependencies()).be.rejected();
|
||||
});
|
||||
|
||||
it('installDependencies should open link for odbc document if user selects the link', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let connection = new azdata.connection.ConnectionProfile();
|
||||
let credentials = { [azdata.ConnectionOptionSpecialType.password]: 'password' };
|
||||
testContext.apiWrapper.setup(x => x.getCurrentConnection()).returns(() => { return Promise.resolve(connection); });
|
||||
testContext.apiWrapper.setup(x => x.getCredentials(TypeMoq.It.isAny())).returns(() => { return Promise.resolve(credentials); });
|
||||
let installedPackages = `[
|
||||
{"name":"pymssql","version":"2.1.4"},
|
||||
{"name":"sqlmlutils","version":"1.1.1"}
|
||||
]`;
|
||||
testContext.apiWrapper.setup(x => x.startBackgroundOperation(TypeMoq.It.isAny())).returns((operationInfo: azdata.BackgroundOperationInfo) => {
|
||||
operationInfo.operation(testContext.op);
|
||||
});
|
||||
testContext.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve(constants.learnMoreTitle));
|
||||
testContext.apiWrapper.setup(x => x.openExternal(TypeMoq.It.isAny())).returns(() => Promise.resolve(true));
|
||||
testContext.processService.setup(x => x.executeBufferedCommand(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => {
|
||||
return Promise.resolve(installedPackages);
|
||||
});
|
||||
|
||||
testContext.processService.setup(x => x.execScripts(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.reject('error'));
|
||||
|
||||
let packageManager = createPackageManager(testContext);
|
||||
await should(packageManager.installDependencies()).be.rejected();
|
||||
testContext.apiWrapper.verify(x => x.openExternal(TypeMoq.It.isAny()), TypeMoq.Times.atMostOnce());
|
||||
});
|
||||
|
||||
it('installDependencies Should install packages that are not already installed', async function (): Promise<void> {
|
||||
let testContext = createContext();
|
||||
let packagesInstalled = false;
|
||||
|
||||
Reference in New Issue
Block a user