mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 10:58:30 -05:00
Machine Learning Services extension with package management feature (#8622)
* Machine Learning Services extension with package management feature
This commit is contained in:
@@ -0,0 +1,156 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
'use strict';
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
import { SqlPythonPackageManageProvider } from './sqlPackageManageProvider';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import * as utils from '../common/utils';
|
||||
import * as constants from '../common/constants';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { ProcessService } from '../common/processService';
|
||||
import { Config } from '../common/config';
|
||||
import { isNullOrUndefined } from 'util';
|
||||
|
||||
export class PackageManager {
|
||||
|
||||
private _pythonExecutable: string = '';
|
||||
private _pythonInstallationLocation: string = '';
|
||||
private _sqlPackageManager: SqlPythonPackageManageProvider | undefined = undefined;
|
||||
|
||||
/**
|
||||
* Creates a new instance of PackageManager
|
||||
*/
|
||||
constructor(
|
||||
private _nbExtensionApis: nbExtensionApis.IExtensionApi,
|
||||
private _outputChannel: vscode.OutputChannel,
|
||||
private _rootFolder: string,
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _queryRunner: QueryRunner,
|
||||
private _processService: ProcessService,
|
||||
private _config: Config) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the instance and resister SQL package manager with manage package dialog
|
||||
*/
|
||||
public init(): void {
|
||||
this._pythonInstallationLocation = utils.getPythonInstallationLocation(this._rootFolder);
|
||||
this._pythonExecutable = utils.getPythonExePath(this._rootFolder);
|
||||
this._sqlPackageManager = new SqlPythonPackageManageProvider(this._nbExtensionApis, this._outputChannel, this._rootFolder, this._apiWrapper, this._queryRunner, this._processService);
|
||||
this._nbExtensionApis.registerPackageManager(SqlPythonPackageManageProvider.ProviderId, this._sqlPackageManager);
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes manage package command for SQL server packages.
|
||||
*/
|
||||
public async managePackages(): Promise<void> {
|
||||
|
||||
// Only execute the command if there's a valid connection with ml configuration enabled
|
||||
//
|
||||
let connection = await this.getCurrentConnection();
|
||||
let isPythonInstalled = await this._queryRunner.isPythonInstalled(connection);
|
||||
if (connection && isPythonInstalled && this._sqlPackageManager) {
|
||||
this._apiWrapper.executeCommand(constants.managePackagesCommand, {
|
||||
multiLocations: false,
|
||||
defaultLocation: this._sqlPackageManager.packageTarget.location,
|
||||
defaultProviderId: SqlPythonPackageManageProvider.ProviderId
|
||||
});
|
||||
} else {
|
||||
this._apiWrapper.showInfoMessage(constants.managePackageCommandError);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs dependencies for the extension
|
||||
*/
|
||||
public async installDependencies(): Promise<void> {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
let msgTaskName = constants.installDependenciesMsgTaskName;
|
||||
this._apiWrapper.startBackgroundOperation({
|
||||
displayName: msgTaskName,
|
||||
description: msgTaskName,
|
||||
isCancelable: false,
|
||||
operation: async op => {
|
||||
try {
|
||||
if (!(await utils.exists(this._pythonExecutable))) {
|
||||
// Install python
|
||||
//
|
||||
await utils.createFolder(this._pythonInstallationLocation);
|
||||
await this.jupyterInstallation.installPythonPackage(op, false, this._pythonInstallationLocation, this._outputChannel);
|
||||
}
|
||||
|
||||
// Install required packages
|
||||
//
|
||||
await this.installRequiredPythonPackages();
|
||||
op.updateStatus(azdata.TaskStatus.Succeeded);
|
||||
resolve();
|
||||
} catch (error) {
|
||||
let errorMsg = constants.installDependenciesError(error ? error.message : '');
|
||||
op.updateStatus(azdata.TaskStatus.Failed, errorMsg);
|
||||
reject(errorMsg);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs required python packages
|
||||
*/
|
||||
private async installRequiredPythonPackages(): Promise<void> {
|
||||
let installedPackages = await this.getInstalledPipPackages();
|
||||
let fileContent = '';
|
||||
this._config.requiredPythonPackages.forEach(packageDetails => {
|
||||
let hasVersion = ('version' in packageDetails) && !isNullOrUndefined(packageDetails['version']) && packageDetails['version'].length > 0;
|
||||
if (!installedPackages.find(x => x.name === packageDetails['name'] && (!hasVersion || packageDetails['version'] === x.version))) {
|
||||
let packageNameDetail = hasVersion ? `${packageDetails.name}==${packageDetails.version}` : `${packageDetails.name}`;
|
||||
fileContent = `${fileContent}${packageNameDetail}\n`;
|
||||
}
|
||||
});
|
||||
|
||||
if (fileContent) {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackages);
|
||||
let result = await utils.execCommandOnTempFile<string>(fileContent, async (tempFilePath) => {
|
||||
return await this.installPackages(tempFilePath);
|
||||
});
|
||||
this._outputChannel.appendLine(result);
|
||||
} else {
|
||||
this._outputChannel.appendLine(constants.installDependenciesPackagesAlreadyInstalled);
|
||||
}
|
||||
}
|
||||
|
||||
private async getInstalledPipPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
try {
|
||||
let cmd = `"${this._pythonExecutable}" -m pip list --format=json`;
|
||||
let packagesInfo = await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
let packagesResult: nbExtensionApis.IPackageDetails[] = [];
|
||||
if (packagesInfo) {
|
||||
packagesResult = <nbExtensionApis.IPackageDetails[]>JSON.parse(packagesInfo);
|
||||
}
|
||||
return packagesResult;
|
||||
}
|
||||
catch (err) {
|
||||
this._outputChannel.appendLine(constants.installDependenciesGetPackagesError(err ? err.message : ''));
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
|
||||
private get jupyterInstallation(): nbExtensionApis.IJupyterServerInstallation {
|
||||
return this._nbExtensionApis.getJupyterController().jupyterInstallation;
|
||||
}
|
||||
|
||||
private async installPackages(requirementFilePath: string): Promise<string> {
|
||||
let cmd = `"${this._pythonExecutable}" -m pip install -r "${requirementFilePath}"`;
|
||||
return await this._processService.executeBufferedCommand(cmd, this._outputChannel);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the Source EULA. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
'use strict';
|
||||
|
||||
import * as vscode from 'vscode';
|
||||
import * as azdata from 'azdata';
|
||||
import * as nbExtensionApis from '../typings/notebookServices';
|
||||
import * as utils from '../common/utils';
|
||||
import * as constants from '../common/constants';
|
||||
import { QueryRunner } from '../common/queryRunner';
|
||||
import { ApiWrapper } from '../common/apiWrapper';
|
||||
import { ProcessService } from '../common/processService';
|
||||
|
||||
const installMode = 'install';
|
||||
const uninstallMode = 'uninstall';
|
||||
const localPythonProviderId = 'localhost_Pip';
|
||||
|
||||
/**
|
||||
* Manage Package Provider for python packages inside SQL server databases
|
||||
*/
|
||||
export class SqlPythonPackageManageProvider implements nbExtensionApis.IPackageManageProvider {
|
||||
|
||||
private _pythonExecutable: string;
|
||||
|
||||
public static ProviderId = 'sql_Python';
|
||||
|
||||
/**
|
||||
* Creates new a instance
|
||||
*/
|
||||
constructor(
|
||||
private _nbExtensionApis: nbExtensionApis.IExtensionApi,
|
||||
private _outputChannel: vscode.OutputChannel,
|
||||
private _rootFolder: string,
|
||||
private _apiWrapper: ApiWrapper,
|
||||
private _queryRunner: QueryRunner,
|
||||
private _processService: ProcessService) {
|
||||
this._pythonExecutable = utils.getPythonExePath(this._rootFolder);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns provider Id
|
||||
*/
|
||||
public get providerId(): string {
|
||||
return SqlPythonPackageManageProvider.ProviderId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns package target
|
||||
*/
|
||||
public get packageTarget(): nbExtensionApis.IPackageTarget {
|
||||
return { location: 'SQL', packageType: 'Python' };
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns list of packages
|
||||
*/
|
||||
public async listPackages(): Promise<nbExtensionApis.IPackageDetails[]> {
|
||||
let packages = await this._queryRunner.getPythonPackages(await this.getCurrentConnection());
|
||||
if (packages) {
|
||||
packages = packages.sort((a, b) => a.name.localeCompare(b.name));
|
||||
} else {
|
||||
packages = [];
|
||||
}
|
||||
return packages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Installs given packages
|
||||
* @param packages Packages to install
|
||||
* @param useMinVersion minimum version
|
||||
*/
|
||||
async installPackages(packages: nbExtensionApis.IPackageDetails[], useMinVersion: boolean): Promise<void> {
|
||||
if (packages) {
|
||||
|
||||
// TODO: install package as parallel
|
||||
for (let index = 0; index < packages.length; index++) {
|
||||
const element = packages[index];
|
||||
await this.updatePackage(element, installMode);
|
||||
}
|
||||
}
|
||||
//TODO: use useMinVersion
|
||||
console.log(useMinVersion);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a script to install or uninstall a python package inside current SQL Server connection
|
||||
* @param packageDetails Packages to install or uninstall
|
||||
* @param scriptMode can be 'install' or 'uninstall'
|
||||
*/
|
||||
private async updatePackage(packageDetails: nbExtensionApis.IPackageDetails, scriptMode: string): Promise<void> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
let credentials = await this._apiWrapper.getCredentials(connection.connectionId);
|
||||
|
||||
if (connection) {
|
||||
let port = '1433';
|
||||
let server = connection.serverName;
|
||||
let database = connection.databaseName ? `, database="${connection.databaseName}"` : '';
|
||||
let index = connection.serverName.indexOf(',');
|
||||
if (index > 0) {
|
||||
port = connection.serverName.substring(index + 1);
|
||||
server = connection.serverName.substring(0, index);
|
||||
}
|
||||
|
||||
let pythonConnectionParts = `server="${server}", port=${port}, uid="${connection.userName}", pwd="${credentials[azdata.ConnectionOptionSpecialType.password]}"${database})`;
|
||||
let pythonCommandScript = scriptMode === installMode ?
|
||||
`pkgmanager.install(package="${packageDetails.name}", version="${packageDetails.version}")` :
|
||||
`pkgmanager.uninstall(package_name="${packageDetails.name}")`;
|
||||
|
||||
let scripts: string[] = [
|
||||
'import sqlmlutils',
|
||||
`connection = sqlmlutils.ConnectionInfo(driver="ODBC Driver 17 for SQL Server", ${pythonConnectionParts}`,
|
||||
'pkgmanager = sqlmlutils.SQLPackageManager(connection)',
|
||||
pythonCommandScript
|
||||
];
|
||||
await this._processService.execScripts(this._pythonExecutable, scripts, this._outputChannel);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Uninstalls given packages
|
||||
* @param packages Packages to uninstall
|
||||
*/
|
||||
async uninstallPackages(packages: nbExtensionApis.IPackageDetails[]): Promise<void> {
|
||||
for (let index = 0; index < packages.length; index++) {
|
||||
const element = packages[index];
|
||||
await this.updatePackage(element, uninstallMode);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the provider can be used
|
||||
*/
|
||||
async canUseProvider(): Promise<boolean> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection && await this._queryRunner.isPythonInstalled(connection)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns package overview for given name
|
||||
* @param packageName Package Name
|
||||
*/
|
||||
async getPackageOverview(packageName: string): Promise<nbExtensionApis.IPackageOverview> {
|
||||
let packagePreview: nbExtensionApis.IPackageOverview = {
|
||||
name: packageName,
|
||||
versions: [],
|
||||
summary: ''
|
||||
};
|
||||
let pythonPackageProvider = this.pythonPackageProvider;
|
||||
if (pythonPackageProvider) {
|
||||
packagePreview = await pythonPackageProvider.getPackageOverview(packageName);
|
||||
}
|
||||
return packagePreview;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns location title
|
||||
*/
|
||||
async getLocationTitle(): Promise<string> {
|
||||
let connection = await this.getCurrentConnection();
|
||||
if (connection) {
|
||||
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`;
|
||||
}
|
||||
return constants.packageManagerNoConnection;
|
||||
}
|
||||
|
||||
private get pythonPackageProvider(): nbExtensionApis.IPackageManageProvider | undefined {
|
||||
let providers = this._nbExtensionApis.getPackageManagers();
|
||||
if (providers && providers.has(localPythonProviderId)) {
|
||||
return providers.get(localPythonProviderId);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private async getCurrentConnection(): Promise<azdata.connection.ConnectionProfile> {
|
||||
return await this._apiWrapper.getCurrentConnection();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user