From fde114bb1480125580dea9f1430f4d61f9aabe46 Mon Sep 17 00:00:00 2001 From: Charles Gagnon Date: Mon, 30 Aug 2021 09:07:33 -0700 Subject: [PATCH] Add STS root folder override (#16927) * Add STS root folder override * Display message to user * Show once for any service --- .../src/credentialstore/credentialstore.ts | 14 ++-- .../src/resourceProvider/resourceProvider.ts | 14 ++-- extensions/mssql/src/sqlToolsServer.ts | 71 +++++++++---------- extensions/mssql/src/utils.ts | 44 ++++++++++++ 4 files changed, 89 insertions(+), 54 deletions(-) diff --git a/extensions/mssql/src/credentialstore/credentialstore.ts b/extensions/mssql/src/credentialstore/credentialstore.ts index 86f39d6879..b316dda9ca 100644 --- a/extensions/mssql/src/credentialstore/credentialstore.ts +++ b/extensions/mssql/src/credentialstore/credentialstore.ts @@ -4,7 +4,7 @@ *--------------------------------------------------------------------------------------------*/ import { SqlOpsDataClient, ClientOptions, SqlOpsFeature } from 'dataprotocol-client'; -import { IConfig, ServerProvider } from '@microsoft/ads-service-downloader'; +import { IConfig } from '@microsoft/ads-service-downloader'; import { ServerOptions, RPCMessageType, ClientCapabilities, ServerCapabilities, TransportKind } from 'vscode-languageclient'; import { Disposable } from 'vscode'; import * as UUID from 'vscode-languageclient/lib/utils/uuid'; @@ -77,17 +77,15 @@ export class CredentialStore { } } - public start() { - let serverdownloader = new ServerProvider(this._config); + public async start(): Promise { let clientOptions: ClientOptions = { providerId: Constants.providerId, features: [CredentialsFeature] }; - return serverdownloader.getOrDownloadServer().then(e => { - let serverOptions = this.generateServerOptions(e); - this._client = new SqlOpsDataClient(Constants.serviceName, serverOptions, clientOptions); - this._client.start(); - }); + const serverPath = await Utils.getOrDownloadServer(this._config); + const serverOptions = this.generateServerOptions(serverPath); + this._client = new SqlOpsDataClient(Constants.serviceName, serverOptions, clientOptions); + this._client.start(); } dispose() { diff --git a/extensions/mssql/src/resourceProvider/resourceProvider.ts b/extensions/mssql/src/resourceProvider/resourceProvider.ts index 41343fc095..5df4e25e93 100644 --- a/extensions/mssql/src/resourceProvider/resourceProvider.ts +++ b/extensions/mssql/src/resourceProvider/resourceProvider.ts @@ -4,7 +4,7 @@ *--------------------------------------------------------------------------------------------*/ import * as azdata from 'azdata'; -import { IConfig, ServerProvider } from '@microsoft/ads-service-downloader'; +import { IConfig } from '@microsoft/ads-service-downloader'; import { SqlOpsDataClient, SqlOpsFeature, ClientOptions } from 'dataprotocol-client'; import { ServerCapabilities, ClientCapabilities, RPCMessageType, ServerOptions, TransportKind } from 'vscode-languageclient'; import * as UUID from 'vscode-languageclient/lib/utils/uuid'; @@ -81,17 +81,15 @@ export class AzureResourceProvider { } } - public start() { - let serverdownloader = new ServerProvider(this._config); + public async start(): Promise { let clientOptions: ClientOptions = { providerId: Constants.providerId, features: [FireWallFeature] }; - return serverdownloader.getOrDownloadServer().then(e => { - let serverOptions = this.generateServerOptions(e); - this._client = new SqlOpsDataClient(Constants.serviceName, serverOptions, clientOptions); - this._client.start(); - }); + const serverPath = await Utils.getOrDownloadServer(this._config); + let serverOptions = this.generateServerOptions(serverPath); + this._client = new SqlOpsDataClient(Constants.serviceName, serverOptions, clientOptions); + this._client.start(); } public dispose() { diff --git a/extensions/mssql/src/sqlToolsServer.ts b/extensions/mssql/src/sqlToolsServer.ts index e35796d3a0..8e56c5f28e 100644 --- a/extensions/mssql/src/sqlToolsServer.ts +++ b/extensions/mssql/src/sqlToolsServer.ts @@ -3,12 +3,12 @@ * Licensed under the Source EULA. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ -import { ServerProvider, IConfig, Events } from '@microsoft/ads-service-downloader'; +import { IConfig, Events } from '@microsoft/ads-service-downloader'; import { ServerOptions, TransportKind } from 'vscode-languageclient'; import * as Constants from './constants'; import * as vscode from 'vscode'; import * as path from 'path'; -import { getCommonLaunchArgsAndCleanupOldLogFiles } from './utils'; +import { getCommonLaunchArgsAndCleanupOldLogFiles, getOrDownloadServer } from './utils'; import { Telemetry, LanguageClientErrorHandler } from './telemetry'; import { SqlOpsDataClient, ClientOptions } from 'dataprotocol-client'; import { TelemetryFeature, AgentServicesFeature, SerializationFeature, AccountFeature, SqlAssessmentServicesFeature, ProfilerFeature } from './features'; @@ -82,10 +82,7 @@ export class SqlToolsServer { this.config.installDirectory = path.join(configDir, this.config.installDirectory); this.config.proxy = vscode.workspace.getConfiguration('http').get('proxy'); this.config.strictSSL = vscode.workspace.getConfiguration('http').get('proxyStrictSSL') || true; - - const serverdownloader = new ServerProvider(this.config); - serverdownloader.eventEmitter.onAny(generateHandleServerProviderEvent()); - return serverdownloader.getOrDownloadServer(); + return getOrDownloadServer(this.config, handleServerProviderEvent); } private activateFeatures(context: AppContext): Promise { @@ -109,39 +106,37 @@ function generateServerOptions(logPath: string, executablePath: string): ServerO return { command: executablePath, args: launchArgs, transport: TransportKind.stdio }; } -function generateHandleServerProviderEvent() { +function handleServerProviderEvent(e: string, ...args: any[]): void { let dots = 0; - return (e: string, ...args: any[]) => { - switch (e) { - case Events.INSTALL_START: - outputChannel.show(true); - statusView.show(); - outputChannel.appendLine(localize('installingServiceChannelMsg', "Installing {0} to {1}", Constants.serviceName, args[0])); - statusView.text = localize('installingServiceStatusMsg', "Installing {0}", Constants.serviceName); - break; - case Events.INSTALL_END: - outputChannel.appendLine(localize('installedServiceChannelMsg', "Installed {0}", Constants.serviceName)); - break; - case Events.DOWNLOAD_START: - outputChannel.appendLine(localize('downloadingServiceChannelMsg', "Downloading {0}", args[0])); - outputChannel.append(localize('downloadingServiceSizeChannelMsg', "({0} KB)", Math.ceil(args[1] / 1024).toLocaleString(vscode.env.language))); - statusView.text = localize('downloadingServiceStatusMsg', "Downloading {0}", Constants.serviceName); - break; - case Events.DOWNLOAD_PROGRESS: - let newDots = Math.ceil(args[0] / 5); - if (newDots > dots) { - outputChannel.append('.'.repeat(newDots - dots)); - dots = newDots; - } - break; - case Events.DOWNLOAD_END: - outputChannel.appendLine(localize('downloadServiceDoneChannelMsg', "Done installing {0}", Constants.serviceName)); - break; - case Events.ENTRY_EXTRACTED: - outputChannel.appendLine(localize('entryExtractedChannelMsg', "Extracted {0} ({1}/{2})", args[0], args[1], args[2])); - break; - } - }; + switch (e) { + case Events.INSTALL_START: + outputChannel.show(true); + statusView.show(); + outputChannel.appendLine(localize('installingServiceChannelMsg', "Installing {0} to {1}", Constants.serviceName, args[0])); + statusView.text = localize('installingServiceStatusMsg', "Installing {0}", Constants.serviceName); + break; + case Events.INSTALL_END: + outputChannel.appendLine(localize('installedServiceChannelMsg', "Installed {0}", Constants.serviceName)); + break; + case Events.DOWNLOAD_START: + outputChannel.appendLine(localize('downloadingServiceChannelMsg', "Downloading {0}", args[0])); + outputChannel.append(localize('downloadingServiceSizeChannelMsg', "({0} KB)", Math.ceil(args[1] / 1024).toLocaleString(vscode.env.language))); + statusView.text = localize('downloadingServiceStatusMsg', "Downloading {0}", Constants.serviceName); + break; + case Events.DOWNLOAD_PROGRESS: + let newDots = Math.ceil(args[0] / 5); + if (newDots > dots) { + outputChannel.append('.'.repeat(newDots - dots)); + dots = newDots; + } + break; + case Events.DOWNLOAD_END: + outputChannel.appendLine(localize('downloadServiceDoneChannelMsg', "Done installing {0}", Constants.serviceName)); + break; + case Events.ENTRY_EXTRACTED: + outputChannel.appendLine(localize('entryExtractedChannelMsg', "Extracted {0} ({1}/{2})", args[0], args[1], args[2])); + break; + } } function getClientOptions(context: AppContext): ClientOptions { diff --git a/extensions/mssql/src/utils.ts b/extensions/mssql/src/utils.ts index e1a22e6103..adac3ba78a 100644 --- a/extensions/mssql/src/utils.ts +++ b/extensions/mssql/src/utils.ts @@ -12,6 +12,8 @@ import * as os from 'os'; import * as findRemoveSync from 'find-remove'; import * as constants from './constants'; import { promises as fs } from 'fs'; +import { IConfig, ServerProvider } from '@microsoft/ads-service-downloader'; +import { env } from 'process'; const configTracingLevel = 'tracingLevel'; const configLogRetentionMinutes = 'logRetentionMinutes'; @@ -304,3 +306,45 @@ export async function exists(path: string): Promise { return false; } } + +const STS_OVERRIDE_ENV_VAR = 'ADS_SQLTOOLSSERVICE'; +let overrideMessageDisplayed = false; +/** + * Gets the full path to the EXE for the specified tools service, downloading it in the process if necessary. The location + * for this can be overridden with an environment variable for debugging or other purposes. + * @param config The configuration values of the server to get/download + * @param handleServerEvent A callback for handling events from the server downloader + * @returns The path to the server exe + */ +export async function getOrDownloadServer(config: IConfig, handleServerEvent?: (e: string, ...args: any[]) => void): Promise { + // This env var is used to override the base install location of STS - primarily to be used for debugging scenarios. + try { + const stsRootPath = env[STS_OVERRIDE_ENV_VAR]; + if (stsRootPath) { + for (const exeFile of config.executableFiles) { + const serverFullPath = path.join(stsRootPath, exeFile); + if (await exists(serverFullPath)) { + const overrideMessage = `Using ${exeFile} from ${stsRootPath}`; + // Display message to the user so they know the override is active, but only once so we don't show too many + if (!overrideMessageDisplayed) { + overrideMessageDisplayed = true; + vscode.window.showInformationMessage(overrideMessage); + } + console.log(overrideMessage); + return serverFullPath; + } + } + console.warn(`Could not find valid SQL Tools Service EXE from ${JSON.stringify(config.executableFiles)} at ${stsRootPath}, falling back to config`); + } + } catch (err) { + console.warn('Unexpected error getting override path for SQL Tools Service client ', err); + // Fall back to config if something unexpected happens here + } + + const serverdownloader = new ServerProvider(config); + if (handleServerEvent) { + serverdownloader.eventEmitter.onAny(handleServerEvent); + } + + return serverdownloader.getOrDownloadServer(); +}