mirror of
https://github.com/ckaczor/azuredatastudio.git
synced 2026-02-16 18:46:40 -05:00
Merge vscode source through release 1.79.2 (#23482)
* log when an editor action doesn't run because of enablement
* notebooks create/dispose editors. this means controllers must be created eagerly (😢) and that notebooks need a custom way of plugging comparison keys for session. works unless creating another session for the same cell of a duplicated editor
* Set offSide in sql lang configuration to true (#183461)
* Fixes #181764 (#183550)
* fix typo
* Always scroll down and focus the input (#183557)
* Fixes #180386 (#183561)
* cli: ensure ordering of rpc server messages (#183558). Sending lots of messages to a stream would block them around the async tokio mutex, which is "fair" so doesn't preserve ordering. Instead, use the write_loop approach introduced to the server_multiplexer for the same reason some time ago.
* fix clippy
* update for May endgame
* testing: allow invalidateTestResults to take an array (#183569)
* Document `ShareProvider` API proposal (#183568)
* Remove mention of VS Code from JSDoc
* Add support for rendering svg and md in welcome message (#183580)
* Remove toggle setting more eagerly (#183584)
* rm message abt macOS
* Change text (#183589). Accidentally changed the wrong file
* cli: improve output for code tunnel status (#183571). Fixes #183570
* [json/css/html] update services (#183595)
* Add experimental setting to enable this dialog
* Fix exporting chat model to JSON before it is initialized (#183597)
* minimum scrolling to reveal the next cell on shift+enter (#183600): do minimum scrolling to reveal the next cell on "Execute cell and select next"
* Fixing Jupyter notebook issue 13263 (#183527): fix for the issue, still need to understand why there is strange focusing
* Tweak proposed API JSDoc (#183590)
* workbench -> workspace
* fix ? operator
* Use active editor and show progress when sharing (#183603)
* use scroll setting variable correctly
* Schedule welcome widget to show once between typing (#183606). Don't re-render if already displayed once
* Add F10 keybinding for debugger step, even on Web (#183510). Fixes #181792. Previously, for Web the keyboard shortcut was Alt-F10, because it was believed that F10 could not be bound on browsers. This turned out to be incorrect, so we make the shortcut consistent (F10) with desktop VS Code, which is also what many other debuggers use. We keep Alt-F10 on web as a secondary keybinding to preserve the experience some web users may have gotten used to by now.
* Also pass process.env
* Restore missing chat clear commands (#183651)
* chore: update electron@22.5.4 (#183716)
* Show remote indicator in web when remoteAuthority is set (#183728)
* feat: .vuerc as json file (#153017). Co-authored-by: Martin Aeschlimann <martinae@microsoft.com>
* Delete --compatibility=1.63 code from the server (#183738)
* Copy vscode.dev link to tunnel generates an invalid link when an untitled workspace is open (#183739)
* Recent devcontainer display string corrupted on Get Started page (#183740)
* Improve "next codeblock" navigation (#183744): operate on the current focused response, or the last one, and scroll to the selected item
* Normalize command title
* Git - run git status if similarityThreshold changes (#183762)
* fix aria-label issue in kb editor. Fixes #182490 (keyboard shortcut read words together for blind users: arrow-key navigation to a row read the keybinding and "when" cell data together, producing "CTRL + FeditorFocus" instead of "CTRL + F editorFocus")
* Status - fix compact padding (#183768)
* Remove angle brackets from VB brackets (#183782). Fixes #183359
* Update language config schema with more details about brackets (#183779)
* fix comment (#183812)
* Support for `Notebook` CodeAction Kind (#183457): nb kind support (wip); allow notebook code actions around single-cell edit check; move notebook code action type out of editor. Co-authored-by: rebornix <penn.lv@gmail.com>
* cli: fix connection default being applied (#183827)
* cli: bump to openssl 1.1.1u (#183828)
* Implement "delete" action for chat history (#183609)
* Use desired file name when generating new md pasted file paths (#183861). Fixes #183851
* Default to filename for markdown new file if empty (#183864). Fixes #183848
* Fix small typo (#183865). Fixes #183819
* Noop when moving a symbol into the file it is already in (#183866). Fixes #183793
* Adjust codeAction validation to account for notebook kind (#183859)
* Make JS/TS `go to configuration` commands work on non-`file:` file systems (#183688). Fixes #183685
* Can't do regex search after opening notebook (#183884). Fixes #183858
* Default to current dir for `move to file` select (#183875). Fixes #183870. `showOpenDialog` seems to ignore `defaultUri` if the file doesn't exist
* Use `<...>` style markdown links when needed (#183876). Fixes #183849
* Remove check for context keys
* Update xterm package
* Enable updating a chat model without triggering incremental typing (#183894)
* Enable chat "move" commands on empty sessions and also imported sessions (#183895)
* Fix command name
* Fix some chat keybindings on windows (#183896)
* "Revert File" on inactive editors are ignored (fix #177557) (#183903)
* Empty reason while switching profile (fix #183775) (#183904)
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4278 (#183910)
* fix https://github.com/microsoft/vscode/issues/183770 (#183914)
* code --status displays a lot of errors before actual status output (fix #183787) (#183915)
* joh/icy manatee (#183917): use idle value for widget of interactive editor controller; also make preview editors idle values (https://github.com/microsoft/vscode/issues/183820)
* Fix #183777 (#183929)
* Fix #182309 (#183925)
* Tree checkbox item -> items (#183931). Fixes #183826
* Fixes #183909 (#183940)
* Fix #183837 (#183943)
* Git - fix #183941 (#183944)
* Update xterm.css. Fixes #181242
* chore: add @ulugbekna and @aiday-mar to my-endgame notebook (#183946)
* Revert "When snippet mode is active, make `Tab` not accept suggestion but advance placeholder". This reverts commit 50a80cdb61511343996ff1d41d0b676c3d329f48.
* revert not focusing completion list when quick suggest happens during snippet
* change `snippetsPreventQuickSuggestions` default to false
* Fix #181446 (#183956)
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4298 (#183957)
* fix: remove extraneous incorrect context keys (#183959). These were actually getting added in getTestItemContextOverlay, and the test ID was using the extended ID which extensions do not know about. Fixes #183612
* Fixes https://github.com/microsoft/monaco-editor/issues/3920 (#183960)
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4324 (#183961)
* fix #183030
* fix #180826 (#183962)
* make message more generic for interactive editor help
* fix #183968
* Keep codeblock toolbar visible when focused
* Fix when clause on "Run in terminal" command
* add important info to help menu
* fix #183970
* Set `isRefactoring` for all TS refactoring edits (#183982)
* consolidate
* Disable move to file in TS versions < 5.2 (#183992). There are still a few key bugs with refactoring; we will ship this as a preview for TS 5.2+ instead of for 5.1
* Polish query accepting (#183995). We shouldn't send the same request to Copilot if the query hasn't changed, so if the query is the same, we short circuit (fixes https://github.com/microsoft/vscode-internalbacklog/issues/4286). Also, when we open in chat, we should use the last accepted query, not what's in the input box (fixes https://github.com/microsoft/vscode-internalbacklog/issues/4280)
* Allow widget to have focus (#184000), so that selecting non-code text works. Fixes https://github.com/microsoft/vscode-internalbacklog/issues/4294
* Fix microsoft/vscode-internalbacklog#4257: mitigate z-index for zone widgets (#184001)
* Change welcome dialog contribution to Eventually
* Misc fixes
* Workspace folder picker entry descriptions are suboptimal for some filesystems (fix #183418) (#184018)
* cli - ignore std error unless verbose (#183787) (#184031)
* joh/inquisitive meerkat (#184034): only stash sessions that are non-empty; only unstash a session once unless new exchanges are made; account for all exchange types (https://github.com/microsoft/vscode-internalbacklog/issues/4281)
* Improve declared components (#184039)
* make sure to read setting (#184040); related to https://github.com/microsoft/vscode/issues/173387#issuecomment-1571696644
* [html] update service (#184049). Fixes #181176
* reset context keys on reset/hide (#184042); fixes https://github.com/microsoft/vscode-internalbacklog/issues/4330
* use `Lazy`, not `IdleValue`, for the IE widget held by the eager controller (#184048) (https://github.com/microsoft/vscode/issues/183820)
* fix https://github.com/microsoft/vscode-internalbacklog/issues/4333 (#184067)
* use undo-loop instead of undo-edit when discarding chat session (#184063); fixes https://github.com/microsoft/vscode-internalbacklog/issues/4118; fix tests, wait for correct state
* Add logging to node download (#184070). For #182951
* re-enable default zone widget revealing when showing (#184072); fixes https://github.com/microsoft/vscode-internalbacklog/issues/4332, also fixes https://github.com/microsoft/vscode-internalbacklog/issues/3784
* fix #178202
* Allow APIs in stable (#184062)
* Fix microsoft/vscode-internalbacklog#4206: override List view whitespace css for monaco editor (#184087)
* Fix JSDoc grammatical error (#184090)
* Pick up TS 5.1.3 (#184091). Fixes #182931
* Misc fixes
* update distro (#184097)
* chore: update electron@22.5.5 (#184116)
* Extension host veto is registered multiple times on restart (fix #183778) (#184127)
* Do not auto start the local web worker extension host (#184137)
* Allow embedders to intercept trustedTypes.createPolicy calls (#184136) (#184100)
* fix: reading from console output for --status on windows and linux (#184138) (#184118)
* Misc fixes
* code --status displays a lot of errors before actual status output (fix #183787) (#184200)
* (cherry-pick to 1.79 from main) Handle galleryExtension failure in featuredExtensionService (#184205) (#184198)
* Fix #184183: multiple output height updates are skipped (#184188)
* Post merge init fixes
* Misc build issues
* disable toggle inline diff on `alt` down (https://github.com/microsoft/vscode-internalbacklog/issues/4342)
* Take into account already activated extensions when computing running locations (#184303) (fixes #184180)
* Avoid `extensionService.getExtension` and use `ActivationKind.Immediate` so that URI handling works while resolving (#184310) (fixes #182217)
* WIP
* rm fish auto injection
* More breaks
* Fix Port Attributes constructor (#184412)
* WIP
* WIP
* Allow extensions to get at the exports of other extensions during resolving (#184487) (fixes #184472)
* do not auto finish session when inline chat widgets have focus (re https://github.com/microsoft/vscode-internalbacklog/issues/4354)
* fix compile errors caused by new base method
* WIP
* WIP
* WIP
* WIP
* Build errors
* unc - fix path traversal bypass
* Bump version
* cherry-pick prod changes from main
* Disable sandbox
* Build break from merge
* bump version
* Merge pull request #184739 from max06/max06/issue184659: Restore ShellIntegration for fish (#184659)
* Git - only add --find-renames if the value is not the default one (#185053) (#184992)
* Cherry-pick: Revert changes to render featured extensions when available (#184747) (#184573): lower timeouts for experimentation and gallery service; revert changes to render extensions when available; add audio cues
* fix: disable app sandbox when --no-sandbox is present (#184913) (#184897)
* fix: loading minimist in packaged builds
* Runtime errors
* UNC allow list checks cannot be disabled in extension host (fix #184989) (#185085). Co-authored-by: Robo <hop2deep@gmail.com>
* Update src/vs/base/node/unc.js
* Add notebook extension
* Fix mangling issues
* Fix mangling issues
* npm install
* npm install
* Issues blocking bundle
* Fix build folder compile errors
* Fix windows bundle build
* Linting fixes
* Fix sqllint issues
* Update yarn.lock files
* Fix unit tests
* Fix a couple breaks from test fixes
* Bump distro
* redo the checkbox style
* Update linux build container dockerfile
* Bump build image tag
* Bump native watch dog package
* Bump node-pty
* Bump distro
* Fix documentation error
* Update distro
* redo the button styles
* Update datasource TS
* Add missing yarn.lock files
* Windows setup fix
* Turn off extension unit tests while investigating
* color box style
* Remove appx
* Turn off test log upload
* update dropdownlist style
* fix universal app build error (#23488)
* Skip flaky bufferContext vscode test

---------

Co-authored-by: Johannes <johannes.rieken@gmail.com>
Co-authored-by: Henning Dieterichs <hdieterichs@microsoft.com>
Co-authored-by: Julien Richard <jairbubbles@hotmail.com>
Co-authored-by: Charles Gagnon <chgagnon@microsoft.com>
Co-authored-by: Megan Rogge <merogge@microsoft.com>
Co-authored-by: meganrogge <megan.rogge@microsoft.com>
Co-authored-by: Rob Lourens <roblourens@gmail.com>
Co-authored-by: Connor Peet <connor@peet.io>
Co-authored-by: Joyce Er <joyce.er@microsoft.com>
Co-authored-by: Bhavya U <bhavyau@microsoft.com>
Co-authored-by: Raymond Zhao <7199958+rzhao271@users.noreply.github.com>
Co-authored-by: Martin Aeschlimann <martinae@microsoft.com>
Co-authored-by: Aaron Munger <aamunger@microsoft.com>
Co-authored-by: Aiday Marlen Kyzy <amarlenkyzy@microsoft.com>
Co-authored-by: rebornix <penn.lv@gmail.com>
Co-authored-by: Ole <oler@google.com>
Co-authored-by: Jean Pierre <jeanp413@hotmail.com>
Co-authored-by: Robo <hop2deep@gmail.com>
Co-authored-by: Yash Singh <saiansh2525@gmail.com>
Co-authored-by: Ladislau Szomoru <3372902+lszomoru@users.noreply.github.com>
Co-authored-by: Ulugbek Abdullaev <ulugbekna@gmail.com>
Co-authored-by: Alex Ross <alros@microsoft.com>
Co-authored-by: Michael Lively <milively@microsoft.com>
Co-authored-by: Matt Bierner <matb@microsoft.com>
Co-authored-by: Andrea Mah <31675041+andreamah@users.noreply.github.com>
Co-authored-by: Benjamin Pasero <benjamin.pasero@microsoft.com>
Co-authored-by: Sandeep Somavarapu <sasomava@microsoft.com>
Co-authored-by: Daniel Imms <2193314+Tyriar@users.noreply.github.com>
Co-authored-by: Tyler James Leonhardt <me@tylerleonhardt.com>
Co-authored-by: Alexandru Dima <alexdima@microsoft.com>
Co-authored-by: Joao Moreno <Joao.Moreno@microsoft.com>
Co-authored-by: Alan Ren <alanren@microsoft.com>
41 cli/src/tunnels/challenge.rs Normal file
@@ -0,0 +1,41 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

#[cfg(not(feature = "vsda"))]
pub fn create_challenge() -> String {
    use rand::distributions::{Alphanumeric, DistString};
    Alphanumeric.sample_string(&mut rand::thread_rng(), 16)
}

#[cfg(not(feature = "vsda"))]
pub fn sign_challenge(challenge: &str) -> String {
    use sha2::{Digest, Sha256};
    let mut hash = Sha256::new();
    hash.update(challenge.as_bytes());
    let result = hash.finalize();
    base64::encode_config(result, base64::URL_SAFE_NO_PAD)
}

#[cfg(not(feature = "vsda"))]
pub fn verify_challenge(challenge: &str, response: &str) -> bool {
    sign_challenge(challenge) == response
}

#[cfg(feature = "vsda")]
pub fn create_challenge() -> String {
    use rand::distributions::{Alphanumeric, DistString};
    let str = Alphanumeric.sample_string(&mut rand::thread_rng(), 16);
    vsda::create_new_message(&str)
}

#[cfg(feature = "vsda")]
pub fn sign_challenge(challenge: &str) -> String {
    vsda::sign(challenge)
}

#[cfg(feature = "vsda")]
pub fn verify_challenge(challenge: &str, response: &str) -> bool {
    vsda::validate(challenge, response)
}
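For orientation, a minimal sketch (not part of the diff) of the intended round-trip in the non-vsda build: the host creates a random challenge, the other side signs it, and the host compares signatures. This assumes the three functions above are in scope.

fn challenge_round_trip() {
    // 16 random alphanumeric characters
    let challenge = create_challenge();
    // URL-safe, unpadded base64 of the SHA-256 digest
    let response = sign_challenge(&challenge);
    assert!(verify_challenge(&challenge, &response));
    assert!(!verify_challenge(&challenge, "tampered"));
}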
799 cli/src/tunnels/code_server.rs Normal file
@@ -0,0 +1,799 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use super::paths::{InstalledServer, ServerPaths};
use crate::async_pipe::get_socket_name;
use crate::constants::{
    APPLICATION_NAME, EDITOR_WEB_URL, QUALITYLESS_PRODUCT_NAME, QUALITYLESS_SERVER_NAME,
};
use crate::download_cache::DownloadCache;
use crate::options::{Quality, TelemetryLevel};
use crate::state::LauncherPaths;
use crate::tunnels::paths::{get_server_folder_name, SERVER_FOLDER_NAME};
use crate::update_service::{
    unzip_downloaded_release, Platform, Release, TargetKind, UpdateService,
};
use crate::util::command::{capture_command, kill_tree};
use crate::util::errors::{wrap, AnyError, CodeError, ExtensionInstallFailed, WrappedError};
use crate::util::http::{self, BoxedHttp};
use crate::util::io::SilentCopyProgress;
use crate::util::machine::process_exists;
use crate::{debug, info, log, spanf, trace, warning};
use lazy_static::lazy_static;
use opentelemetry::KeyValue;
use regex::Regex;
use serde::Deserialize;
use std::fs;
use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::fs::remove_file;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::oneshot::Receiver;
use tokio::time::{interval, timeout};

lazy_static! {
    static ref LISTENING_PORT_RE: Regex =
        Regex::new(r"Extension host agent listening on (.+)").unwrap();
    static ref WEB_UI_RE: Regex = Regex::new(r"Web UI available at (.+)").unwrap();
}

#[derive(Clone, Debug, Default)]
pub struct CodeServerArgs {
    pub host: Option<String>,
    pub port: Option<u16>,
    pub socket_path: Option<String>,

    // common arguments
    pub telemetry_level: Option<TelemetryLevel>,
    pub log: Option<log::Level>,
    pub accept_server_license_terms: bool,
    pub verbose: bool,
    // extension management
    pub install_extensions: Vec<String>,
    pub uninstall_extensions: Vec<String>,
    pub list_extensions: bool,
    pub show_versions: bool,
    pub category: Option<String>,
    pub pre_release: bool,
    pub force: bool,
    pub start_server: bool,
    // connection tokens
    pub connection_token: Option<String>,
    pub connection_token_file: Option<String>,
    pub without_connection_token: bool,
}

impl CodeServerArgs {
    pub fn log_level(&self) -> log::Level {
        if self.verbose {
            log::Level::Trace
        } else {
            self.log.unwrap_or(log::Level::Info)
        }
    }

    pub fn telemetry_disabled(&self) -> bool {
        self.telemetry_level == Some(TelemetryLevel::Off)
    }

    pub fn command_arguments(&self) -> Vec<String> {
        let mut args = Vec::new();
        if let Some(i) = &self.socket_path {
            args.push(format!("--socket-path={}", i));
        } else {
            if let Some(i) = &self.host {
                args.push(format!("--host={}", i));
            }
            if let Some(i) = &self.port {
                args.push(format!("--port={}", i));
            }
        }

        if let Some(i) = &self.connection_token {
            args.push(format!("--connection-token={}", i));
        }
        if let Some(i) = &self.connection_token_file {
            args.push(format!("--connection-token-file={}", i));
        }
        if self.without_connection_token {
            args.push(String::from("--without-connection-token"));
        }
        if self.accept_server_license_terms {
            args.push(String::from("--accept-server-license-terms"));
        }
        if let Some(i) = self.telemetry_level {
            args.push(format!("--telemetry-level={}", i));
        }
        if let Some(i) = self.log {
            args.push(format!("--log={}", i));
        }

        for extension in &self.install_extensions {
            args.push(format!("--install-extension={}", extension));
        }
        if !&self.install_extensions.is_empty() {
            if self.pre_release {
                args.push(String::from("--pre-release"));
            }
            if self.force {
                args.push(String::from("--force"));
            }
        }
        for extension in &self.uninstall_extensions {
            args.push(format!("--uninstall-extension={}", extension));
        }
        if self.list_extensions {
            args.push(String::from("--list-extensions"));
            if self.show_versions {
                args.push(String::from("--show-versions"));
            }
            if let Some(i) = &self.category {
                args.push(format!("--category={}", i));
            }
        }
        if self.start_server {
            args.push(String::from("--start-server"));
        }
        args
    }
}
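A usage sketch (not part of the diff) of how `command_arguments()` flattens the struct into server flags; note that the socket-path branch suppresses host/port. Field values here are illustrative.

fn example_command_arguments() {
    let args = CodeServerArgs {
        socket_path: Some("/tmp/code.sock".to_string()),
        host: Some("127.0.0.1".to_string()), // ignored: socket_path takes precedence
        accept_server_license_terms: true,
        ..Default::default()
    };
    assert_eq!(
        args.command_arguments(),
        vec![
            "--socket-path=/tmp/code.sock".to_string(),
            "--accept-server-license-terms".to_string(),
        ]
    );
}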
/// Base server params that can be `resolve()`d to a `ResolvedServerParams`.
/// Doing so fetches additional information like a commit ID if previously
/// unspecified.
pub struct ServerParamsRaw {
    pub commit_id: Option<String>,
    pub quality: Quality,
    pub code_server_args: CodeServerArgs,
    pub headless: bool,
    pub platform: Platform,
}

/// Server params that can be used to start a VS Code server.
pub struct ResolvedServerParams {
    pub release: Release,
    pub code_server_args: CodeServerArgs,
}

impl ResolvedServerParams {
    fn as_installed_server(&self) -> InstalledServer {
        InstalledServer {
            commit: self.release.commit.clone(),
            quality: self.release.quality,
            headless: self.release.target == TargetKind::Server,
        }
    }
}

impl ServerParamsRaw {
    pub async fn resolve(
        self,
        log: &log::Logger,
        http: BoxedHttp,
    ) -> Result<ResolvedServerParams, AnyError> {
        Ok(ResolvedServerParams {
            release: self.get_or_fetch_commit_id(log, http).await?,
            code_server_args: self.code_server_args,
        })
    }

    async fn get_or_fetch_commit_id(
        &self,
        log: &log::Logger,
        http: BoxedHttp,
    ) -> Result<Release, AnyError> {
        let target = match self.headless {
            true => TargetKind::Server,
            false => TargetKind::Web,
        };

        if let Some(c) = &self.commit_id {
            return Ok(Release {
                commit: c.clone(),
                quality: self.quality,
                target,
                name: String::new(),
                platform: self.platform,
            });
        }

        UpdateService::new(log.clone(), http)
            .get_latest_commit(self.platform, target, self.quality)
            .await
    }
}

#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
struct UpdateServerVersion {
    pub name: String,
    pub version: String,
    pub product_version: String,
    pub timestamp: i64,
}

/// Code server listening on a socket address.
#[derive(Clone)]
pub struct SocketCodeServer {
    pub commit_id: String,
    pub socket: PathBuf,
    pub origin: Arc<CodeServerOrigin>,
}

/// Code server listening on a port address.
#[derive(Clone)]
pub struct PortCodeServer {
    pub commit_id: String,
    pub port: u16,
    pub origin: Arc<CodeServerOrigin>,
}

/// A server listening on any address/location.
pub enum AnyCodeServer {
    Socket(SocketCodeServer),
    Port(PortCodeServer),
}

pub enum CodeServerOrigin {
    /// A new code server that was spawned by this process.
    New(Box<Child>),
    /// An existing code server with a PID.
    Existing(u32),
}

impl CodeServerOrigin {
    pub async fn wait_for_exit(&mut self) {
        match self {
            CodeServerOrigin::New(child) => {
                child.wait().await.ok();
            }
            CodeServerOrigin::Existing(pid) => {
                let mut interval = interval(Duration::from_secs(30));
                while process_exists(*pid) {
                    interval.tick().await;
                }
            }
        }
    }

    pub async fn kill(&mut self) {
        match self {
            CodeServerOrigin::New(child) => {
                child.kill().await.ok();
            }
            CodeServerOrigin::Existing(pid) => {
                kill_tree(*pid).await.ok();
            }
        }
    }
}
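A small sketch (not part of the diff) of dispatching on `AnyCodeServer` after a lookup such as `get_running()`, assuming the types above are in scope:

fn describe(server: &AnyCodeServer) -> String {
    match server {
        AnyCodeServer::Port(p) => format!("port {} (commit {})", p.port, p.commit_id),
        AnyCodeServer::Socket(s) => format!("socket {} (commit {})", s.socket.display(), s.commit_id),
    }
}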
/// Ensures the given list of extensions are installed on the running server.
async fn do_extension_install_on_running_server(
    start_script_path: &Path,
    extensions: &[String],
    log: &log::Logger,
) -> Result<(), AnyError> {
    if extensions.is_empty() {
        return Ok(());
    }

    debug!(log, "Installing extensions...");
    let command = format!(
        "{} {}",
        start_script_path.display(),
        extensions
            .iter()
            .map(|s| get_extensions_flag(s))
            .collect::<Vec<String>>()
            .join(" ")
    );

    let result = capture_command("bash", &["-c", &command]).await?;
    if !result.status.success() {
        Err(AnyError::from(ExtensionInstallFailed(
            String::from_utf8_lossy(&result.stderr).to_string(),
        )))
    } else {
        Ok(())
    }
}

pub struct ServerBuilder<'a> {
    logger: &'a log::Logger,
    server_params: &'a ResolvedServerParams,
    launcher_paths: &'a LauncherPaths,
    server_paths: ServerPaths,
    http: BoxedHttp,
}

impl<'a> ServerBuilder<'a> {
    pub fn new(
        logger: &'a log::Logger,
        server_params: &'a ResolvedServerParams,
        launcher_paths: &'a LauncherPaths,
        http: BoxedHttp,
    ) -> Self {
        Self {
            logger,
            server_params,
            launcher_paths,
            server_paths: server_params
                .as_installed_server()
                .server_paths(launcher_paths),
            http,
        }
    }

    /// Gets any already-running server from this directory.
    pub async fn get_running(&self) -> Result<Option<AnyCodeServer>, AnyError> {
        info!(
            self.logger,
            "Checking {} and {} for a running server...",
            self.server_paths.logfile.display(),
            self.server_paths.pidfile.display()
        );

        let pid = match self.server_paths.get_running_pid() {
            Some(pid) => pid,
            None => return Ok(None),
        };
        info!(self.logger, "Found running server (pid={})", pid);
        if !Path::new(&self.server_paths.logfile).exists() {
            warning!(self.logger, "{} Server is running but its logfile is missing. Don't delete the {} Server manually, run the command '{} prune'.", QUALITYLESS_PRODUCT_NAME, QUALITYLESS_PRODUCT_NAME, APPLICATION_NAME);
            return Ok(None);
        }

        do_extension_install_on_running_server(
            &self.server_paths.executable,
            &self.server_params.code_server_args.install_extensions,
            self.logger,
        )
        .await?;

        let origin = Arc::new(CodeServerOrigin::Existing(pid));
        let contents = fs::read_to_string(&self.server_paths.logfile)
            .expect("Something went wrong reading log file");

        if let Some(port) = parse_port_from(&contents) {
            Ok(Some(AnyCodeServer::Port(PortCodeServer {
                commit_id: self.server_params.release.commit.to_owned(),
                port,
                origin,
            })))
        } else if let Some(socket) = parse_socket_from(&contents) {
            Ok(Some(AnyCodeServer::Socket(SocketCodeServer {
                commit_id: self.server_params.release.commit.to_owned(),
                socket,
                origin,
            })))
        } else {
            Ok(None)
        }
    }

    /// Ensures the server is set up in the configured directory.
    pub async fn setup(&self) -> Result<(), AnyError> {
        debug!(
            self.logger,
            "Installing and setting up {}...", QUALITYLESS_SERVER_NAME
        );

        let update_service = UpdateService::new(self.logger.clone(), self.http.clone());
        let name = get_server_folder_name(
            self.server_params.release.quality,
            &self.server_params.release.commit,
        );

        self.launcher_paths
            .server_cache
            .create(name, |target_dir| async move {
                let tmpdir =
                    tempfile::tempdir().map_err(|e| wrap(e, "error creating temp download dir"))?;

                let response = update_service
                    .get_download_stream(&self.server_params.release)
                    .await?;
                let archive_path = tmpdir.path().join(response.url_path_basename().unwrap());

                info!(
                    self.logger,
                    "Downloading {} server -> {}",
                    QUALITYLESS_PRODUCT_NAME,
                    archive_path.display()
                );

                http::download_into_file(
                    &archive_path,
                    self.logger.get_download_logger("server download progress:"),
                    response,
                )
                .await?;

                unzip_downloaded_release(
                    &archive_path,
                    &target_dir.join(SERVER_FOLDER_NAME),
                    SilentCopyProgress(),
                )?;

                Ok(())
            })
            .await?;

        debug!(self.logger, "Server setup complete");

        Ok(())
    }

    pub async fn listen_on_port(&self, port: u16) -> Result<PortCodeServer, AnyError> {
        let mut cmd = self.get_base_command();
        cmd.arg("--start-server")
            .arg("--enable-remote-auto-shutdown")
            .arg(format!("--port={}", port));

        let child = self.spawn_server_process(cmd)?;
        let log_file = self.get_logfile()?;
        let plog = self.logger.prefixed(&log::new_code_server_prefix());

        let (mut origin, listen_rx) =
            monitor_server::<PortMatcher, u16>(child, Some(log_file), plog, false);

        let port = match timeout(Duration::from_secs(8), listen_rx).await {
            Err(e) => {
                origin.kill().await;
                Err(wrap(e, "timed out looking for port"))
            }
            Ok(Err(e)) => {
                origin.kill().await;
                Err(wrap(e, "server exited without writing port"))
            }
            Ok(Ok(p)) => Ok(p),
        }?;

        info!(self.logger, "Server started");

        Ok(PortCodeServer {
            commit_id: self.server_params.release.commit.to_owned(),
            port,
            origin: Arc::new(origin),
        })
    }

    pub async fn listen_on_default_socket(&self) -> Result<SocketCodeServer, AnyError> {
        let requested_file = get_socket_name();
        self.listen_on_socket(&requested_file).await
    }

    pub async fn listen_on_socket(&self, socket: &Path) -> Result<SocketCodeServer, AnyError> {
        Ok(spanf!(
            self.logger,
            self.logger.span("server.start").with_attributes(vec! {
                KeyValue::new("commit_id", self.server_params.release.commit.to_string()),
                KeyValue::new("quality", format!("{}", self.server_params.release.quality)),
            }),
            self._listen_on_socket(socket)
        )?)
    }

    async fn _listen_on_socket(&self, socket: &Path) -> Result<SocketCodeServer, AnyError> {
        remove_file(&socket).await.ok(); // ignore any error if it doesn't exist

        let mut cmd = self.get_base_command();
        cmd.arg("--start-server")
            .arg("--enable-remote-auto-shutdown")
            .arg(format!("--socket-path={}", socket.display()));

        let child = self.spawn_server_process(cmd)?;
        let log_file = self.get_logfile()?;
        let plog = self.logger.prefixed(&log::new_code_server_prefix());

        let (mut origin, listen_rx) =
            monitor_server::<SocketMatcher, PathBuf>(child, Some(log_file), plog, false);

        let socket = match timeout(Duration::from_secs(8), listen_rx).await {
            Err(e) => {
                origin.kill().await;
                Err(wrap(e, "timed out looking for socket"))
            }
            Ok(Err(e)) => {
                origin.kill().await;
                Err(wrap(e, "server exited without writing socket"))
            }
            Ok(Ok(socket)) => Ok(socket),
        }?;

        info!(self.logger, "Server started");

        Ok(SocketCodeServer {
            commit_id: self.server_params.release.commit.to_owned(),
            socket,
            origin: Arc::new(origin),
        })
    }

    /// Starts with a given opaque set of args. Does not set up any port or
    /// socket, but does return one if present, in the form of a channel.
    pub async fn start_opaque_with_args<M, R>(
        &self,
        args: &[String],
    ) -> Result<(CodeServerOrigin, Receiver<R>), AnyError>
    where
        M: ServerOutputMatcher<R>,
        R: 'static + Send + std::fmt::Debug,
    {
        let mut cmd = self.get_base_command();
        cmd.args(args);

        let child = self.spawn_server_process(cmd)?;
        let plog = self.logger.prefixed(&log::new_code_server_prefix());

        Ok(monitor_server::<M, R>(child, None, plog, true))
    }

    fn spawn_server_process(&self, mut cmd: Command) -> Result<Child, AnyError> {
        info!(self.logger, "Starting server...");

        debug!(self.logger, "Starting server with command... {:?}", cmd);

        let child = cmd
            .stderr(std::process::Stdio::piped())
            .stdout(std::process::Stdio::piped())
            .spawn()
            .map_err(|e| wrap(e, "error spawning server"))?;

        self.server_paths
            .write_pid(child.id().expect("expected server to have pid"))?;

        Ok(child)
    }

    fn get_logfile(&self) -> Result<File, WrappedError> {
        File::create(&self.server_paths.logfile).map_err(|e| {
            wrap(
                e,
                format!(
                    "error creating log file {}",
                    self.server_paths.logfile.display()
                ),
            )
        })
    }

    fn get_base_command(&self) -> Command {
        let mut cmd = Command::new(&self.server_paths.executable);
        cmd.stdin(std::process::Stdio::null())
            .args(self.server_params.code_server_args.command_arguments());
        cmd
    }
}
fn monitor_server<M, R>(
    mut child: Child,
    log_file: Option<File>,
    plog: log::Logger,
    write_directly: bool,
) -> (CodeServerOrigin, Receiver<R>)
where
    M: ServerOutputMatcher<R>,
    R: 'static + Send + std::fmt::Debug,
{
    let stdout = child
        .stdout
        .take()
        .expect("child did not have a handle to stdout");

    let stderr = child
        .stderr
        .take()
        .expect("child did not have a handle to stderr");

    let (listen_tx, listen_rx) = tokio::sync::oneshot::channel();

    // Handle stderr and stdout in a separate task. Initially scan lines looking
    // for the listening port. Afterwards, just scan and write out to the file.
    tokio::spawn(async move {
        let mut stdout_reader = BufReader::new(stdout).lines();
        let mut stderr_reader = BufReader::new(stderr).lines();
        let write_line = |line: &str| -> std::io::Result<()> {
            if let Some(mut f) = log_file.as_ref() {
                f.write_all(line.as_bytes())?;
                f.write_all(&[b'\n'])?;
            }
            if write_directly {
                println!("{}", line);
            } else {
                trace!(plog, line);
            }
            Ok(())
        };

        loop {
            let line = tokio::select! {
                l = stderr_reader.next_line() => l,
                l = stdout_reader.next_line() => l,
            };

            match line {
                Err(e) => {
                    trace!(plog, "error reading from stdout/stderr: {}", e);
                    return;
                }
                Ok(None) => break,
                Ok(Some(l)) => {
                    write_line(&l).ok();

                    if let Some(listen_on) = M::match_line(&l) {
                        trace!(plog, "parsed location: {:?}", listen_on);
                        listen_tx.send(listen_on).ok();
                        break;
                    }
                }
            }
        }

        loop {
            let line = tokio::select! {
                l = stderr_reader.next_line() => l,
                l = stdout_reader.next_line() => l,
            };

            match line {
                Err(e) => {
                    trace!(plog, "error reading from stdout/stderr: {}", e);
                    break;
                }
                Ok(None) => break,
                Ok(Some(l)) => {
                    write_line(&l).ok();
                }
            }
        }
    });

    let origin = CodeServerOrigin::New(Box::new(child));
    (origin, listen_rx)
}
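The heart of `monitor_server` is the two-phase select loop above. As a standalone sketch (not part of the diff), the same pumping pattern reduces to: read whichever stream produces a line first, and stop on EOF or a read error. All names here are illustrative.

async fn pump_lines(
    stdout: impl tokio::io::AsyncBufRead + Unpin,
    stderr: impl tokio::io::AsyncBufRead + Unpin,
) {
    use tokio::io::AsyncBufReadExt;
    let mut out = stdout.lines();
    let mut err = stderr.lines();
    loop {
        let line = tokio::select! {
            l = err.next_line() => l,
            l = out.next_line() => l,
        };
        match line {
            Ok(Some(l)) => println!("{}", l), // a real caller would log or match the line
            Ok(None) => break,                // stream closed
            Err(_) => break,                  // read error
        }
    }
}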
fn get_extensions_flag(extension_id: &str) -> String {
    format!("--install-extension={}", extension_id)
}

/// A type that can be used to scan stdout from the VS Code server. Returns
/// some other type that, in turn, is returned from starting the server.
pub trait ServerOutputMatcher<R>
where
    R: Send,
{
    fn match_line(line: &str) -> Option<R>;
}

/// Parses a line like "Extension host agent listening on /tmp/foo.sock"
struct SocketMatcher();

impl ServerOutputMatcher<PathBuf> for SocketMatcher {
    fn match_line(line: &str) -> Option<PathBuf> {
        parse_socket_from(line)
    }
}

/// Parses a line like "Extension host agent listening on 9000"
pub struct PortMatcher();

impl ServerOutputMatcher<u16> for PortMatcher {
    fn match_line(line: &str) -> Option<u16> {
        parse_port_from(line)
    }
}

/// Parses a line like "Web UI available at http://localhost:9000/?tkn=..."
pub struct WebUiMatcher();

impl ServerOutputMatcher<reqwest::Url> for WebUiMatcher {
    fn match_line(line: &str) -> Option<reqwest::Url> {
        WEB_UI_RE.captures(line).and_then(|cap| {
            cap.get(1)
                .and_then(|uri| reqwest::Url::parse(uri.as_str()).ok())
        })
    }
}

/// Does not do any parsing and just immediately returns an empty result.
pub struct NoOpMatcher();

impl ServerOutputMatcher<()> for NoOpMatcher {
    fn match_line(_: &str) -> Option<()> {
        Some(())
    }
}

fn parse_socket_from(text: &str) -> Option<PathBuf> {
    LISTENING_PORT_RE
        .captures(text)
        .and_then(|cap| cap.get(1).map(|path| PathBuf::from(path.as_str())))
}

fn parse_port_from(text: &str) -> Option<u16> {
    LISTENING_PORT_RE.captures(text).and_then(|cap| {
        cap.get(1)
            .and_then(|path| path.as_str().parse::<u16>().ok())
    })
}
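Quick checks (not part of the diff) for the matchers and parsers above, using the log-line shapes the doc comments mention; these assume same-module access, since parse_socket_from is private.

fn parser_examples() {
    assert_eq!(
        parse_port_from("Extension host agent listening on 9000"),
        Some(9000)
    );
    assert_eq!(
        parse_socket_from("Extension host agent listening on /tmp/foo.sock"),
        Some(PathBuf::from("/tmp/foo.sock"))
    );
    // PortMatcher is parse_port_from exposed through the ServerOutputMatcher trait.
    assert_eq!(
        PortMatcher::match_line("Extension host agent listening on 9000"),
        Some(9000)
    );
}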
pub fn print_listening(log: &log::Logger, tunnel_name: &str) {
    debug!(
        log,
        "{} is listening for incoming connections", QUALITYLESS_SERVER_NAME
    );

    let home_dir = dirs::home_dir().unwrap_or_else(|| PathBuf::from(""));
    let current_dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(""));

    let dir = if home_dir == current_dir {
        PathBuf::from("")
    } else {
        current_dir
    };

    let base_web_url = match EDITOR_WEB_URL {
        Some(u) => u,
        None => return,
    };

    let mut addr = url::Url::parse(base_web_url).unwrap();
    {
        let mut ps = addr.path_segments_mut().unwrap();
        ps.push("tunnel");
        ps.push(tunnel_name);
        for segment in &dir {
            let as_str = segment.to_string_lossy();
            if !(as_str.len() == 1 && as_str.starts_with(std::path::MAIN_SEPARATOR)) {
                ps.push(as_str.as_ref());
            }
        }
    }

    let message = &format!("\nOpen this link in your browser {}\n", addr);
    log.result(message);
}

pub async fn download_cli_into_cache(
    cache: &DownloadCache,
    release: &Release,
    update_service: &UpdateService,
) -> Result<PathBuf, AnyError> {
    let cache_name = format!(
        "{}-{}-{}",
        release.quality, release.commit, release.platform
    );
    let cli_dir = cache
        .create(&cache_name, |target_dir| async move {
            let tmpdir =
                tempfile::tempdir().map_err(|e| wrap(e, "error creating temp download dir"))?;
            let response = update_service.get_download_stream(release).await?;

            let name = response.url_path_basename().unwrap();
            let archive_path = tmpdir.path().join(name);
            http::download_into_file(&archive_path, SilentCopyProgress(), response).await?;
            unzip_downloaded_release(&archive_path, &target_dir, SilentCopyProgress())?;
            Ok(())
        })
        .await?;

    let cli = std::fs::read_dir(cli_dir)
        .map_err(|_| CodeError::CorruptDownload("could not read cli folder contents"))?
        .next();

    match cli {
        Some(Ok(cli)) => Ok(cli.path()),
        _ => {
            let _ = cache.delete(&cache_name);
            Err(CodeError::CorruptDownload("cli directory is empty").into())
        }
    }
}
1168 cli/src/tunnels/control_server.rs Normal file
File diff suppressed because it is too large
997 cli/src/tunnels/dev_tunnels.rs Normal file
@@ -0,0 +1,997 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::auth;
use crate::constants::{
    CONTROL_PORT, IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, PROTOCOL_VERSION_TAG_PREFIX,
    TUNNEL_SERVICE_USER_AGENT,
};
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{
    wrap, AnyError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed, WrappedError,
};
use crate::util::input::prompt_placeholder;
use crate::{debug, info, log, spanf, trace, warning};
use async_trait::async_trait;
use futures::TryFutureExt;
use lazy_static::lazy_static;
use rand::prelude::IteratorRandom;
use regex::Regex;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tunnels::connections::{ForwardedPortConnection, RelayTunnelHost};
use tunnels::contracts::{
    Tunnel, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN, TUNNEL_PROTOCOL_AUTO,
};
use tunnels::management::{
    new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
    NO_REQUEST_OPTIONS,
};

#[derive(Clone, Serialize, Deserialize)]
pub struct PersistedTunnel {
    pub name: String,
    pub id: String,
    pub cluster: String,
}

impl PersistedTunnel {
    pub fn into_locator(self) -> TunnelLocator {
        TunnelLocator::ID {
            cluster: self.cluster,
            id: self.id,
        }
    }
    pub fn locator(&self) -> TunnelLocator {
        TunnelLocator::ID {
            cluster: self.cluster.clone(),
            id: self.id.clone(),
        }
    }
}

#[async_trait]
trait AccessTokenProvider: Send + Sync {
    /// Gets the current access token.
    async fn refresh_token(&self) -> Result<String, WrappedError>;
}

/// Access token provider that provides a fixed token without refreshing.
struct StaticAccessTokenProvider(String);

impl StaticAccessTokenProvider {
    pub fn new(token: String) -> Self {
        Self(token)
    }
}

#[async_trait]
impl AccessTokenProvider for StaticAccessTokenProvider {
    async fn refresh_token(&self) -> Result<String, WrappedError> {
        Ok(self.0.clone())
    }
}

/// Access token provider that looks up the token from the tunnels API.
struct LookupAccessTokenProvider {
    client: TunnelManagementClient,
    locator: TunnelLocator,
    log: log::Logger,
    initial_token: Arc<Mutex<Option<String>>>,
}

impl LookupAccessTokenProvider {
    pub fn new(
        client: TunnelManagementClient,
        locator: TunnelLocator,
        log: log::Logger,
        initial_token: Option<String>,
    ) -> Self {
        Self {
            client,
            locator,
            log,
            initial_token: Arc::new(Mutex::new(initial_token)),
        }
    }
}

#[async_trait]
impl AccessTokenProvider for LookupAccessTokenProvider {
    async fn refresh_token(&self) -> Result<String, WrappedError> {
        if let Some(token) = self.initial_token.lock().unwrap().take() {
            return Ok(token);
        }

        let tunnel_lookup = spanf!(
            self.log,
            self.log.span("dev-tunnel.tag.get"),
            self.client.get_tunnel(
                &self.locator,
                &TunnelRequestOptions {
                    token_scopes: vec!["host".to_string()],
                    ..Default::default()
                }
            )
        );

        trace!(self.log, "Successfully refreshed access token");

        match tunnel_lookup {
            Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)),
            Err(e) => Err(wrap(e, "failed to lookup tunnel for host token")),
        }
    }
}
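A sketch (not part of the diff) showing that, via async_trait, the provider trait is object-safe, so callers can hold either implementation behind a Box<dyn AccessTokenProvider>; this assumes same-module access, since the trait is private.

async fn token_example() -> Result<(), WrappedError> {
    let provider: Box<dyn AccessTokenProvider> =
        Box::new(StaticAccessTokenProvider::new("fixed-token".to_string()));
    // The static provider always hands back its fixed token.
    assert_eq!(provider.refresh_token().await?, "fixed-token");
    Ok(())
}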
#[derive(Clone)]
pub struct DevTunnels {
    log: log::Logger,
    launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
    client: TunnelManagementClient,
}

/// Representation of a tunnel returned from the `start` methods.
pub struct ActiveTunnel {
    /// Name of the tunnel
    pub name: String,
    /// Underlying dev tunnels ID
    pub id: String,
    manager: ActiveTunnelManager,
}

impl ActiveTunnel {
    /// Closes and unregisters the tunnel.
    pub async fn close(&mut self) -> Result<(), AnyError> {
        self.manager.kill().await?;
        Ok(())
    }

    /// Forwards a port to local connections.
    pub async fn add_port_direct(
        &mut self,
        port_number: u16,
    ) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, AnyError> {
        let port = self.manager.add_port_direct(port_number).await?;
        Ok(port)
    }

    /// Forwards a port over TCP.
    pub async fn add_port_tcp(&mut self, port_number: u16) -> Result<(), AnyError> {
        self.manager.add_port_tcp(port_number).await?;
        Ok(())
    }

    /// Removes a forwarded TCP port.
    pub async fn remove_port(&mut self, port_number: u16) -> Result<(), AnyError> {
        self.manager.remove_port(port_number).await?;
        Ok(())
    }

    /// Gets the public URI on which a forwarded port can be accessed in a browser.
    pub async fn get_port_uri(&mut self, port: u16) -> Result<String, AnyError> {
        let endpoint = self.manager.get_endpoint().await?;
        let format = endpoint
            .base
            .port_uri_format
            .expect("expected to have port format");

        Ok(format.replace(PORT_TOKEN, &port.to_string()))
    }
}

const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
const MAX_TUNNEL_NAME_LENGTH: usize = 20;

fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
    tunnel
        .access_tokens
        .as_ref()
        .expect("expected to have access tokens")
        .get("host")
        .expect("expected to have host token")
        .to_string()
}

fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> {
    if name.len() > MAX_TUNNEL_NAME_LENGTH {
        return Err(InvalidTunnelName(format!(
            "Names cannot be longer than {} characters. Please try a different name.",
            MAX_TUNNEL_NAME_LENGTH
        )));
    }

    let re = Regex::new(r"^([\w-]+)$").unwrap();

    if !re.is_match(name) {
        return Err(InvalidTunnelName(
            "Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string()
        ));
    }

    Ok(())
}

lazy_static! {
    static ref HOST_TUNNEL_REQUEST_OPTIONS: TunnelRequestOptions = TunnelRequestOptions {
        include_ports: true,
        token_scopes: vec!["host".to_string()],
        ..Default::default()
    };
}

/// Structure optionally passed into `start_existing_tunnel` to forward an existing tunnel.
#[derive(Clone, Debug)]
pub struct ExistingTunnel {
    /// Name to assign to the preexisting tunnel used to connect to the VS Code Server.
    pub tunnel_name: String,

    /// Token to authenticate and use the preexisting tunnel.
    pub host_token: String,

    /// ID of the preexisting tunnel used to connect to the VS Code Server.
    pub tunnel_id: String,

    /// Cluster of the preexisting tunnel used to connect to the VS Code Server.
    pub cluster: String,
}
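Quick checks (not part of the diff) for is_valid_name above: at most 20 characters, matching ^([\w-]+)$ (letters, digits, underscores, and '-'); assumes same-module access, since the function is private.

fn name_examples() {
    assert!(is_valid_name("my-tunnel-1").is_ok());
    assert!(is_valid_name("has spaces").is_err()); // spaces are rejected
    assert!(is_valid_name(&"x".repeat(21)).is_err()); // over the 20-character limit
}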
impl DevTunnels {
    pub fn new(log: &log::Logger, auth: auth::Auth, paths: &LauncherPaths) -> DevTunnels {
        let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
        client.authorization_provider(auth);

        DevTunnels {
            log: log.clone(),
            client: client.into(),
            launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
        }
    }

    pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> {
        let tunnel = match self.launcher_tunnel.load() {
            Some(t) => t,
            None => {
                return Ok(());
            }
        };

        spanf!(
            self.log,
            self.log.span("dev-tunnel.delete"),
            self.client
                .delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS)
        )
        .map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;

        self.launcher_tunnel.save(None)?;
        Ok(())
    }

    /// Renames the current tunnel to the new name.
    pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
        self.update_tunnel_name(None, name).await.map(|_| ())
    }

    /// Updates the name of the existing persisted tunnel to the new name.
    /// Gracefully creates a new tunnel if the previous one was deleted.
    async fn update_tunnel_name(
        &mut self,
        persisted: Option<PersistedTunnel>,
        name: &str,
    ) -> Result<(Tunnel, PersistedTunnel), AnyError> {
        let name = name.to_ascii_lowercase();
        self.check_is_name_free(&name).await?;

        debug!(self.log, "Tunnel name changed, applying updates...");

        let (mut full_tunnel, mut persisted, is_new) = match persisted {
            Some(persisted) => {
                self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
                    .await
            }
            None => self
                .create_tunnel(&name, NO_REQUEST_OPTIONS)
                .await
                .map(|(pt, t)| (t, pt, true)),
        }?;

        if is_new {
            return Ok((full_tunnel, persisted));
        }

        full_tunnel.tags = vec![name.to_string(), VSCODE_CLI_TUNNEL_TAG.to_string()];

        let new_tunnel = spanf!(
            self.log,
            self.log.span("dev-tunnel.tag.update"),
            self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
        )
        .map_err(|e| wrap(e, "failed to rename tunnel"))?;

        persisted.name = name;
        self.launcher_tunnel.save(Some(persisted.clone()))?;

        Ok((new_tunnel, persisted))
    }

    /// Gets the persisted tunnel from the service, or creates a new one.
    /// If `create_with_new_name` is given, the new tunnel has that name
    /// instead of the one previously persisted.
    async fn get_or_create_tunnel(
        &mut self,
        persisted: PersistedTunnel,
        create_with_new_name: Option<&str>,
        options: &TunnelRequestOptions,
    ) -> Result<(Tunnel, PersistedTunnel, /* is_new */ bool), AnyError> {
        let tunnel_lookup = spanf!(
            self.log,
            self.log.span("dev-tunnel.tag.get"),
            self.client.get_tunnel(&persisted.locator(), options)
        );

        match tunnel_lookup {
            Ok(ft) => Ok((ft, persisted, false)),
            Err(HttpError::ResponseError(e))
                if e.status_code == StatusCode::NOT_FOUND
                    || e.status_code == StatusCode::FORBIDDEN =>
            {
                let (persisted, tunnel) = self
                    .create_tunnel(create_with_new_name.unwrap_or(&persisted.name), options)
                    .await?;
                Ok((tunnel, persisted, true))
            }
            Err(e) => Err(wrap(e, "failed to lookup tunnel").into()),
        }
    }

    /// Starts a new tunnel for the code server on the port. Unlike `start_new_tunnel`,
    /// this attempts to reuse or create a tunnel of a preferred name or of a generated friendly tunnel name.
    pub async fn start_new_launcher_tunnel(
        &mut self,
        preferred_name: Option<&str>,
        use_random_name: bool,
    ) -> Result<ActiveTunnel, AnyError> {
        let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
            Some(mut persisted) => {
                if let Some(preferred_name) = preferred_name.map(|n| n.to_ascii_lowercase()) {
                    if persisted.name.to_ascii_lowercase() != preferred_name {
                        (_, persisted) = self
                            .update_tunnel_name(Some(persisted), &preferred_name)
                            .await?;
                    }
                }

                let (tunnel, persisted, _) = self
                    .get_or_create_tunnel(persisted, None, &HOST_TUNNEL_REQUEST_OPTIONS)
                    .await?;
                (tunnel, persisted)
            }
            None => {
                debug!(self.log, "No code server tunnel found, creating new one");
                let name = self
                    .get_name_for_tunnel(preferred_name, use_random_name)
                    .await?;
                let (persisted, full_tunnel) = self
                    .create_tunnel(&name, &HOST_TUNNEL_REQUEST_OPTIONS)
                    .await?;
                (full_tunnel, persisted)
            }
        };

        if !tunnel.tags.iter().any(|t| t == PROTOCOL_VERSION_TAG) {
            tunnel = self
                .update_protocol_version_tag(tunnel, &HOST_TUNNEL_REQUEST_OPTIONS)
                .await?;
        }

        let locator = TunnelLocator::try_from(&tunnel).unwrap();
        let host_token = get_host_token_from_tunnel(&tunnel);

        for port_to_delete in tunnel
            .ports
            .iter()
            .filter(|p| p.port_number != CONTROL_PORT)
        {
            let output_fut = self.client.delete_tunnel_port(
                &locator,
                port_to_delete.port_number,
                NO_REQUEST_OPTIONS,
            );
            spanf!(
                self.log,
                self.log.span("dev-tunnel.port.delete"),
                output_fut
            )
            .map_err(|e| wrap(e, "failed to delete port"))?;
        }

        // cleanup any old trailing tunnel endpoints
        for endpoint in tunnel.endpoints {
            let fut = self.client.delete_tunnel_endpoints(
                &locator,
                &endpoint.host_id,
                None,
                NO_REQUEST_OPTIONS,
            );

            spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut)
                .map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?;
        }

        self.start_tunnel(
            locator.clone(),
            &persisted,
            self.client.clone(),
            LookupAccessTokenProvider::new(
                self.client.clone(),
                locator,
                self.log.clone(),
                Some(host_token),
            ),
        )
        .await
    }

    async fn create_tunnel(
        &mut self,
        name: &str,
        options: &TunnelRequestOptions,
    ) -> Result<(PersistedTunnel, Tunnel), AnyError> {
        info!(self.log, "Creating tunnel with the name: {}", name);

        let mut tried_recycle = false;

        let new_tunnel = Tunnel {
            tags: vec![
                name.to_string(),
                PROTOCOL_VERSION_TAG.to_string(),
                VSCODE_CLI_TUNNEL_TAG.to_string(),
            ],
            ..Default::default()
        };

        loop {
            let result = spanf!(
                self.log,
                self.log.span("dev-tunnel.create"),
                self.client.create_tunnel(&new_tunnel, options)
            );

            match result {
                Err(HttpError::ResponseError(e))
                    if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
                {
                    if !tried_recycle && self.try_recycle_tunnel().await? {
                        tried_recycle = true;
                        continue;
                    }

                    return Err(AnyError::from(TunnelCreationFailed(
                        name.to_string(),
                        "You've exceeded the 10 machine limit for the port forwarding service. Please remove other machines before trying to add this machine.".to_string(),
                    )));
                }
                Err(e) => {
                    return Err(AnyError::from(TunnelCreationFailed(
                        name.to_string(),
                        format!("{:?}", e),
                    )))
                }
                Ok(t) => {
                    let pt = PersistedTunnel {
                        cluster: t.cluster_id.clone().unwrap(),
                        id: t.tunnel_id.clone().unwrap(),
                        name: name.to_string(),
                    };

                    self.launcher_tunnel.save(Some(pt.clone()))?;
                    return Ok((pt, t));
                }
            }
        }
    }

    /// Ensures the tunnel contains a tag for the current PROTOCOL_VERSION, and no
    /// other version tags.
    async fn update_protocol_version_tag(
        &self,
        tunnel: Tunnel,
        options: &TunnelRequestOptions,
    ) -> Result<Tunnel, AnyError> {
        debug!(
            self.log,
            "Updating tunnel protocol version tag to {}", PROTOCOL_VERSION_TAG
        );
        let mut new_tags: Vec<String> = tunnel
            .tags
            .into_iter()
            .filter(|t| !t.starts_with(PROTOCOL_VERSION_TAG_PREFIX))
            .collect();
        new_tags.push(PROTOCOL_VERSION_TAG.to_string());
||||
let tunnel_update = Tunnel {
|
||||
tags: new_tags,
|
||||
tunnel_id: tunnel.tunnel_id.clone(),
|
||||
cluster_id: tunnel.cluster_id.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = spanf!(
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.protocol-tag-update"),
|
||||
self.client.update_tunnel(&tunnel_update, options)
|
||||
);
|
||||
|
||||
result.map_err(|e| wrap(e, "tunnel tag update failed").into())
|
||||
}
|
||||
|
||||
/// Tries to delete an unused tunnel, and then creates a tunnel with the
|
||||
/// given `new_name`.
|
||||
async fn try_recycle_tunnel(&mut self) -> Result<bool, AnyError> {
|
||||
trace!(
|
||||
self.log,
|
||||
"Tunnel limit hit, trying to recycle an old tunnel"
|
||||
);
|
||||
|
||||
let existing_tunnels = self.list_all_server_tunnels().await?;
|
||||
|
||||
let recyclable = existing_tunnels
|
||||
.iter()
|
||||
.filter(|t| {
|
||||
t.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.host_connection_count.as_ref())
|
||||
.map(|c| c.get_count())
|
||||
.unwrap_or(0) == 0
|
||||
})
|
||||
.choose(&mut rand::thread_rng());
|
||||
|
||||
match recyclable {
|
||||
Some(tunnel) => {
|
||||
trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id);
|
||||
spanf!(
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.delete"),
|
||||
self.client
|
||||
.delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS)
|
||||
)
|
||||
.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
|
||||
Ok(true)
|
||||
}
|
||||
None => {
|
||||
trace!(self.log, "No tunnels available to recycle");
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_all_server_tunnels(&mut self) -> Result<Vec<Tunnel>, AnyError> {
|
||||
let tunnels = spanf!(
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.listall"),
|
||||
self.client.list_all_tunnels(&TunnelRequestOptions {
|
||||
tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string()],
|
||||
require_all_tags: true,
|
||||
..Default::default()
|
||||
})
|
||||
)
|
||||
.map_err(|e| wrap(e, "error listing current tunnels"))?;
|
||||
|
||||
Ok(tunnels)
|
||||
}
|
||||
|
||||
async fn check_is_name_free(&mut self, name: &str) -> Result<(), AnyError> {
|
||||
let existing = spanf!(
|
||||
self.log,
|
||||
self.log.span("dev-tunnel.rename.search"),
|
||||
self.client.list_all_tunnels(&TunnelRequestOptions {
|
||||
tags: vec![VSCODE_CLI_TUNNEL_TAG.to_string(), name.to_string()],
|
||||
require_all_tags: true,
|
||||
..Default::default()
|
||||
})
|
||||
)
|
||||
.map_err(|e| wrap(e, "failed to list existing tunnels"))?;
|
||||
if !existing.is_empty() {
|
||||
return Err(AnyError::from(TunnelCreationFailed(
|
||||
name.to_string(),
|
||||
"tunnel name already in use".to_string(),
|
||||
)));
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_name_for_tunnel(
|
||||
&mut self,
|
||||
preferred_name: Option<&str>,
|
||||
mut use_random_name: bool,
|
||||
) -> Result<String, AnyError> {
|
||||
let existing_tunnels = self.list_all_server_tunnels().await?;
|
||||
let is_name_free = |n: &str| {
|
||||
!existing_tunnels.iter().any(|v| {
|
||||
v.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.host_connection_count.as_ref().map(|c| c.get_count()))
|
||||
.unwrap_or(0) > 0 && v.tags.iter().any(|t| t == n)
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(machine_name) = preferred_name {
|
||||
let name = machine_name.to_ascii_lowercase();
|
||||
if let Err(e) = is_valid_name(&name) {
|
||||
info!(self.log, "{} is an invalid name", e);
|
||||
return Err(AnyError::from(wrap(e, "invalid name")));
|
||||
}
|
||||
if is_name_free(&name) {
|
||||
return Ok(name);
|
||||
}
|
||||
info!(
|
||||
self.log,
|
||||
"{} is already taken, using a random name instead", &name
|
||||
);
|
||||
use_random_name = true;
|
||||
}
|
||||
|
||||
let mut placeholder_name =
|
||||
clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
|
||||
placeholder_name.make_ascii_lowercase();
|
||||
|
||||
if !is_name_free(&placeholder_name) {
|
||||
for i in 2.. {
|
||||
let fixed_name = format!("{}{}", placeholder_name, i);
|
||||
if is_name_free(&fixed_name) {
|
||||
placeholder_name = fixed_name;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if use_random_name || !*IS_INTERACTIVE_CLI {
|
||||
return Ok(placeholder_name);
|
||||
}
|
||||
|
||||
loop {
|
||||
let mut name = prompt_placeholder(
|
||||
"What would you like to call this machine?",
|
||||
&placeholder_name,
|
||||
)?;
|
||||
|
||||
name.make_ascii_lowercase();
|
||||
|
||||
if let Err(e) = is_valid_name(&name) {
|
||||
info!(self.log, "{}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_name_free(&name) {
|
||||
return Ok(name);
|
||||
}
|
||||
|
||||
info!(self.log, "The name {} is already in use", name);
|
||||
}
|
||||
}
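A note on the collision handling above: when the hostname-derived placeholder is taken, the code appends an increasing numeric suffix starting at 2 until a free name is found. A standalone sketch of that scheme (the `dedupe` helper and sample names here are illustrative, not part of the CLI):

```rust
/// Minimal sketch of the suffix scheme in `get_name_for_tunnel`: starting
/// at 2, append an increasing counter until the candidate is unused.
fn dedupe(base: &str, taken: &[&str]) -> String {
    if !taken.contains(&base) {
        return base.to_string();
    }
    (2..)
        .map(|i| format!("{}{}", base, i))
        .find(|candidate| !taken.contains(&candidate.as_str()))
        .unwrap() // the range is unbounded, so a free name always exists
}

fn main() {
    let taken = ["mymachine", "mymachine2"];
    assert_eq!(dedupe("mymachine", &taken), "mymachine3");
    assert_eq!(dedupe("other", &taken), "other");
}
```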
    /// Hosts an existing tunnel, where the tunnel ID and host token are given.
    pub async fn start_existing_tunnel(
        &mut self,
        tunnel: ExistingTunnel,
    ) -> Result<ActiveTunnel, AnyError> {
        let tunnel_details = PersistedTunnel {
            name: tunnel.tunnel_name,
            id: tunnel.tunnel_id,
            cluster: tunnel.cluster,
        };

        let mut mgmt = self.client.build();
        mgmt.authorization(tunnels::management::Authorization::Tunnel(
            tunnel.host_token.clone(),
        ));

        self.start_tunnel(
            tunnel_details.locator(),
            &tunnel_details,
            mgmt.into(),
            StaticAccessTokenProvider::new(tunnel.host_token),
        )
        .await
    }

    async fn start_tunnel(
        &mut self,
        locator: TunnelLocator,
        tunnel_details: &PersistedTunnel,
        client: TunnelManagementClient,
        access_token: impl AccessTokenProvider + 'static,
    ) -> Result<ActiveTunnel, AnyError> {
        let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token);

        let endpoint_result = spanf!(
            self.log,
            self.log.span("dev-tunnel.serve.callback"),
            manager.get_endpoint()
        );

        let endpoint = match endpoint_result {
            Ok(endpoint) => endpoint,
            Err(e) => {
                error!(self.log, "Error connecting to tunnel endpoint: {}", e);
                manager.kill().await.ok();
                return Err(e);
            }
        };

        debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint);

        Ok(ActiveTunnel {
            name: tunnel_details.name.clone(),
            id: tunnel_details.id.clone(),
            manager,
        })
    }
}

struct ActiveTunnelManager {
    close_tx: Option<mpsc::Sender<()>>,
    endpoint_rx: watch::Receiver<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
    relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
}

impl ActiveTunnelManager {
    pub fn new(
        log: log::Logger,
        mgmt: TunnelManagementClient,
        locator: TunnelLocator,
        access_token: impl AccessTokenProvider + 'static,
    ) -> ActiveTunnelManager {
        let (endpoint_tx, endpoint_rx) = watch::channel(None);
        let (close_tx, close_rx) = mpsc::channel(1);

        let relay = Arc::new(tokio::sync::Mutex::new(RelayTunnelHost::new(locator, mgmt)));
        let relay_spawned = relay.clone();

        tokio::spawn(async move {
            ActiveTunnelManager::spawn_tunnel(
                log,
                relay_spawned,
                close_rx,
                endpoint_tx,
                access_token,
            )
            .await;
        });

        ActiveTunnelManager {
            endpoint_rx,
            relay,
            close_tx: Some(close_tx),
        }
    }

    /// Adds a port for TCP/IP forwarding.
    #[allow(dead_code)] // todo: port forwarding
    pub async fn add_port_tcp(&self, port_number: u16) -> Result<(), WrappedError> {
        self.relay
            .lock()
            .await
            .add_port(&TunnelPort {
                port_number,
                protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
                ..Default::default()
            })
            .await
            .map_err(|e| wrap(e, "error adding port to relay"))?;
        Ok(())
    }

    /// Adds a port for direct (in-process) forwarding, returning the stream
    /// of connections made to it.
    pub async fn add_port_direct(
        &self,
        port_number: u16,
    ) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, WrappedError> {
        self.relay
            .lock()
            .await
            .add_port_raw(&TunnelPort {
                port_number,
                protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
                ..Default::default()
            })
            .await
            .map_err(|e| wrap(e, "error adding port to relay"))
    }

    /// Removes a port from TCP/IP forwarding.
    pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> {
        self.relay
            .lock()
            .await
            .remove_port(port_number)
            .await
            .map_err(|e| wrap(e, "error removing port from relay"))
    }

    /// Gets the most recent endpoint details from the tunnel. Returns an
    /// error if the tunnel was shut down before providing details.
    pub async fn get_endpoint(&mut self) -> Result<TunnelRelayTunnelEndpoint, AnyError> {
        loop {
            if let Some(details) = &*self.endpoint_rx.borrow() {
                return details.clone().map_err(AnyError::from);
            }

            if self.endpoint_rx.changed().await.is_err() {
                return Err(DevTunnelError("tunnel creation cancelled".to_string()).into());
            }
        }
    }

    /// Kills the process, and waits for it to exit.
    /// See https://tokio.rs/tokio/topics/shutdown#waiting-for-things-to-finish-shutting-down for how this works
    pub async fn kill(&mut self) -> Result<(), AnyError> {
        if let Some(tx) = self.close_tx.take() {
            drop(tx);
        }

        self.relay
            .lock()
            .await
            .unregister()
            .await
            .map_err(|e| wrap(e, "error unregistering relay"))?;

        while self.endpoint_rx.changed().await.is_ok() {}

        Ok(())
    }
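The `kill` implementation above relies on the tokio shutdown idiom referenced in its doc comment: dropping the only `mpsc::Sender` makes the receiver's `recv()` resolve to `None`, which the background task treats as its exit signal. A self-contained sketch of just that mechanism (the names here are illustrative):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (close_tx, mut close_rx) = mpsc::channel::<()>(1);

    let worker = tokio::spawn(async move {
        // `recv()` yields `None` once every Sender has been dropped.
        while close_rx.recv().await.is_some() {}
        println!("close signal observed, shutting down");
    });

    drop(close_tx); // the analogue of `self.close_tx.take()` in `kill`
    worker.await.unwrap();
}
```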
    async fn spawn_tunnel(
        log: log::Logger,
        relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
        mut close_rx: mpsc::Receiver<()>,
        endpoint_tx: watch::Sender<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
        access_token_provider: impl AccessTokenProvider + 'static,
    ) {
        let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));

        macro_rules! fail {
            ($e: expr, $msg: expr) => {
                warning!(log, "{}: {}", $msg, $e);
                endpoint_tx.send(Some(Err($e))).ok();
                backoff.delay().await;
            };
        }

        loop {
            debug!(log, "Starting tunnel to server...");

            let access_token = match access_token_provider.refresh_token().await {
                Ok(t) => t,
                Err(e) => {
                    fail!(e, "Error refreshing access token, will retry");
                    continue;
                }
            };

            // we don't bother making a client that can refresh the token, since
            // the tunnel won't be able to host as soon as the access token expires.
            let handle_res = {
                let mut relay = relay.lock().await;
                relay
                    .connect(&access_token)
                    .await
                    .map_err(|e| wrap(e, "error connecting to tunnel"))
            };

            let mut handle = match handle_res {
                Ok(handle) => handle,
                Err(e) => {
                    fail!(e, "Error connecting to relay, will retry");
                    continue;
                }
            };

            backoff.reset();
            endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok();

            tokio::select! {
                // the error is mapped like this to prevent it being used across an await,
                // which Rust dislikes since there's a non-sendable dyn Error in there
                res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => {
                    if let Err(e) = res {
                        fail!(e, "Tunnel exited unexpectedly, reconnecting");
                    } else {
                        warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting");
                        backoff.delay().await;
                    }
                },
                _ = close_rx.recv() => {
                    trace!(log, "Tunnel closing gracefully");
                    trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
                    break;
                }
            }
        }
    }
}

struct Backoff {
    failures: u32,
    base_duration: Duration,
    max_duration: Duration,
}

impl Backoff {
    pub fn new(base_duration: Duration, max_duration: Duration) -> Self {
        Self {
            failures: 0,
            base_duration,
            max_duration,
        }
    }

    pub async fn delay(&mut self) {
        tokio::time::sleep(self.next()).await
    }

    pub fn next(&mut self) -> Duration {
        self.failures += 1;
        let duration = self
            .base_duration
            .checked_mul(self.failures)
            .unwrap_or(self.max_duration);
        std::cmp::min(duration, self.max_duration)
    }

    pub fn reset(&mut self) {
        self.failures = 0;
    }
}
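Note that `Backoff` is linear rather than exponential: the delay is `base_duration * failures`, saturating at `max_duration` (the `checked_mul` fallback also maps arithmetic overflow to the maximum). With the 5s/120s values used by `spawn_tunnel`, the standalone sketch below prints the resulting schedule:

```rust
use std::time::Duration;

// Reproduces the schedule of `Backoff::next`: base * failures, capped at max.
fn schedule(base: Duration, max: Duration, failures: u32) -> Duration {
    base.checked_mul(failures).map_or(max, |d| d.min(max))
}

fn main() {
    let (base, max) = (Duration::from_secs(5), Duration::from_secs(120));
    for failures in 1..=26 {
        // 5s, 10s, 15s, ... 120s, then 120s from the 24th failure on
        println!("failure #{}: wait {:?}", failures, schedule(base, max, failures));
    }
}
```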
/// Cleans up the hostname so it can be used as a tunnel name.
/// See TUNNEL_NAME_PATTERN in the tunnels SDK for the rules we try to use.
fn clean_hostname_for_tunnel(hostname: &str) -> String {
    let mut out = String::new();
    for char in hostname.chars().take(60) {
        match char {
            '-' | '_' | ' ' => {
                out.push('-');
            }
            '0'..='9' | 'a'..='z' | 'A'..='Z' => {
                out.push(char);
            }
            _ => {}
        }
    }

    let trimmed = out.trim_matches('-');
    if trimmed.len() < 2 {
        "remote-machine".to_string() // placeholder if the result was empty or too short
    } else {
        trimmed.to_owned()
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_clean_hostname_for_tunnel() {
        assert_eq!(
            clean_hostname_for_tunnel("hello123"),
            "hello123".to_string()
        );
        assert_eq!(
            clean_hostname_for_tunnel("-cool-name-"),
            "cool-name".to_string()
        );
        assert_eq!(
            clean_hostname_for_tunnel("cool!name with_chars"),
            "coolname-with-chars".to_string()
        );
        assert_eq!(clean_hostname_for_tunnel("z"), "remote-machine".to_string());
    }
}
cli/src/tunnels/legal.rs (new file, 66 lines)
@@ -0,0 +1,66 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use crate::constants::{IS_INTERACTIVE_CLI, PRODUCT_NAME_LONG};
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{AnyError, MissingLegalConsent};
use crate::util::input::prompt_yn;
use serde::{Deserialize, Serialize};

const LICENSE_TEXT: Option<&'static str> = option_env!("VSCODE_CLI_REMOTE_LICENSE_TEXT");
const LICENSE_PROMPT: Option<&'static str> = option_env!("VSCODE_CLI_REMOTE_LICENSE_PROMPT");

#[derive(Clone, Default, Serialize, Deserialize)]
struct PersistedConsent {
    pub consented: Option<bool>,
}

pub fn require_consent(
    paths: &LauncherPaths,
    accept_server_license_terms: bool,
) -> Result<(), AnyError> {
    match LICENSE_TEXT {
        Some(t) => println!("{}", t.replace("\\n", "\r\n")),
        None => return Ok(()),
    }

    let prompt = match LICENSE_PROMPT {
        Some(p) => p,
        None => return Ok(()),
    };

    let license: PersistedState<PersistedConsent> =
        PersistedState::new(paths.root().join("license_consent.json"));

    let mut load = license.load();
    if let Some(true) = load.consented {
        return Ok(());
    }

    if accept_server_license_terms {
        load.consented = Some(true);
    } else if !*IS_INTERACTIVE_CLI {
        return Err(MissingLegalConsent(
            "Run this command again with --accept-server-license-terms to indicate your agreement."
                .to_string(),
        )
        .into());
    } else {
        match prompt_yn(prompt) {
            Ok(true) => {
                load.consented = Some(true);
            }
            Ok(false) => {
                return Err(AnyError::from(MissingLegalConsent(format!(
                    "Sorry, you cannot use the {} CLI without accepting the terms.",
                    PRODUCT_NAME_LONG
                ))))
            }
            Err(e) => return Err(AnyError::from(MissingLegalConsent(e.to_string()))),
        }
    }

    license.save(load)?;
    Ok(())
}
cli/src/tunnels/nosleep.rs (new file, 13 lines)
@@ -0,0 +1,13 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

#[cfg(target_os = "windows")]
pub type SleepInhibitor = super::nosleep_windows::SleepInhibitor;

#[cfg(target_os = "linux")]
pub type SleepInhibitor = super::nosleep_linux::SleepInhibitor;

#[cfg(target_os = "macos")]
pub type SleepInhibitor = super::nosleep_macos::SleepInhibitor;
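All three platform implementations behind this alias are RAII guards: constructing one requests sleep inhibition (a D-Bus call, an `IOPMAssertion`, or a Windows power request), and dropping it releases the inhibition again. A minimal stand-in demonstrating the usage pattern (the `FakeInhibitor` type is hypothetical):

```rust
struct FakeInhibitor;

impl FakeInhibitor {
    fn new() -> Self {
        println!("sleep inhibition requested");
        FakeInhibitor
    }
}

impl Drop for FakeInhibitor {
    fn drop(&mut self) {
        println!("sleep inhibition released");
    }
}

fn main() {
    let _guard = FakeInhibitor::new(); // keep alive for the duration of the work
    println!("doing long-running work");
} // `_guard` dropped here; the OS may sleep again
```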
cli/src/tunnels/nosleep_linux.rs (new file, 79 lines)
@@ -0,0 +1,79 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use zbus::{dbus_proxy, Connection};

use crate::{
    constants::APPLICATION_NAME,
    util::errors::{wrap, AnyError},
};

/// A basically undocumented API, but it seems widely implemented and is what
/// browsers use for sleep inhibition. The downside is that it also *may*
/// disable the screensaver. A much better and more granular API is available
/// on `org.freedesktop.login1.Manager`, but this requires administrative
/// permission to request inhibition, which is not possible here.
///
/// See https://source.chromium.org/chromium/chromium/src/+/main:services/device/wake_lock/power_save_blocker/power_save_blocker_linux.cc;l=54;drc=2e85357a8b76996981cc6f783853a49df2cedc3a
#[dbus_proxy(
    interface = "org.freedesktop.PowerManagement.Inhibit",
    gen_blocking = false,
    default_service = "org.freedesktop.PowerManagement.Inhibit",
    default_path = "/org/freedesktop/PowerManagement/Inhibit"
)]
trait PMInhibitor {
    #[dbus_proxy(name = "Inhibit")]
    fn inhibit(&self, what: &str, why: &str) -> zbus::Result<u32>;
}

/// A slightly better documented version which seems commonly used.
#[dbus_proxy(
    interface = "org.freedesktop.ScreenSaver",
    gen_blocking = false,
    default_service = "org.freedesktop.ScreenSaver",
    default_path = "/org/freedesktop/ScreenSaver"
)]
trait ScreenSaver {
    #[dbus_proxy(name = "Inhibit")]
    fn inhibit(&self, what: &str, why: &str) -> zbus::Result<u32>;
}

pub struct SleepInhibitor {
    _connection: Connection, // Inhibition is released when the connection is closed
}

impl SleepInhibitor {
    pub async fn new() -> Result<Self, AnyError> {
        let connection = Connection::session()
            .await
            .map_err(|e| wrap(e, "error creating dbus session"))?;

        macro_rules! try_inhibit {
            ($proxy:ident) => {
                match $proxy::new(&connection).await {
                    Ok(proxy) => proxy.inhibit(APPLICATION_NAME, "running tunnel").await,
                    Err(e) => Err(e),
                }
            };
        }

        if let Err(e1) = try_inhibit!(PMInhibitorProxy) {
            if let Err(e2) = try_inhibit!(ScreenSaverProxy) {
                return Err(wrap(
                    e2,
                    format!(
                        "error requesting sleep inhibition, pminhibitor gave {}, screensaver gave",
                        e1
                    ),
                )
                .into());
            }
        }

        Ok(SleepInhibitor {
            _connection: connection,
        })
    }
}
cli/src/tunnels/nosleep_macos.rs (new file, 78 lines)
@@ -0,0 +1,78 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::io;

use core_foundation::base::TCFType;
use core_foundation::string::{CFString, CFStringRef};
use libc::c_int;

use crate::constants::TUNNEL_ACTIVITY_NAME;

extern "C" {
    pub fn IOPMAssertionCreateWithName(
        assertion_type: CFStringRef,
        assertion_level: u32,
        assertion_name: CFStringRef,
        assertion_id: &mut u32,
    ) -> c_int;

    pub fn IOPMAssertionRelease(assertion_id: u32) -> c_int;
}

const NUM_ASSERTIONS: usize = 2;

const ASSERTIONS: [&str; NUM_ASSERTIONS] = ["PreventUserIdleSystemSleep", "PreventSystemSleep"];

struct Assertion(u32);

impl Assertion {
    pub fn make(typ: &CFString, name: &CFString) -> io::Result<Self> {
        let mut assertion_id = 0;
        let result = unsafe {
            IOPMAssertionCreateWithName(
                typ.as_concrete_TypeRef(),
                255,
                name.as_concrete_TypeRef(),
                &mut assertion_id,
            )
        };

        if result != 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(Self(assertion_id))
        }
    }
}

impl Drop for Assertion {
    fn drop(&mut self) {
        unsafe {
            IOPMAssertionRelease(self.0);
        }
    }
}

pub struct SleepInhibitor {
    _assertions: Vec<Assertion>,
}

impl SleepInhibitor {
    pub async fn new() -> io::Result<Self> {
        let mut assertions = Vec::with_capacity(NUM_ASSERTIONS);
        let assertion_name = CFString::from_static_string(TUNNEL_ACTIVITY_NAME);
        for typ in ASSERTIONS {
            assertions.push(Assertion::make(
                &CFString::from_static_string(typ),
                &assertion_name,
            )?);
        }

        Ok(Self {
            _assertions: assertions,
        })
    }
}
cli/src/tunnels/nosleep_windows.rs (new file, 79 lines)
@@ -0,0 +1,79 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::io;

use winapi::{
    ctypes::c_void,
    um::{
        handleapi::CloseHandle,
        minwinbase::REASON_CONTEXT,
        winbase::{PowerClearRequest, PowerCreateRequest, PowerSetRequest},
        winnt::{
            PowerRequestSystemRequired, POWER_REQUEST_CONTEXT_SIMPLE_STRING,
            POWER_REQUEST_CONTEXT_VERSION, POWER_REQUEST_TYPE,
        },
    },
};

use crate::constants::TUNNEL_ACTIVITY_NAME;

struct Request(*mut c_void);

impl Request {
    pub fn new() -> io::Result<Self> {
        let mut reason: Vec<u16> = TUNNEL_ACTIVITY_NAME.encode_utf16().collect();
        let mut context = REASON_CONTEXT {
            Version: POWER_REQUEST_CONTEXT_VERSION,
            Flags: POWER_REQUEST_CONTEXT_SIMPLE_STRING,
            ..Default::default()
        };
        unsafe { *context.Reason.SimpleReasonString_mut() = reason.as_mut_ptr() };

        let request = unsafe { PowerCreateRequest(&mut context) };
        if request.is_null() {
            return Err(io::Error::last_os_error());
        }

        Ok(Self(request))
    }

    pub fn set(&self, request_type: POWER_REQUEST_TYPE) -> io::Result<()> {
        let result = unsafe { PowerSetRequest(self.0, request_type) };
        if result == 0 {
            return Err(io::Error::last_os_error());
        }

        Ok(())
    }
}

impl Drop for Request {
    fn drop(&mut self) {
        unsafe {
            CloseHandle(self.0);
        }
    }
}

pub struct SleepInhibitor {
    request: Request,
}

impl SleepInhibitor {
    pub async fn new() -> io::Result<Self> {
        let request = Request::new()?;
        request.set(PowerRequestSystemRequired)?;
        Ok(Self { request })
    }
}

impl Drop for SleepInhibitor {
    fn drop(&mut self) {
        unsafe {
            PowerClearRequest(self.request.0, PowerRequestSystemRequired);
        }
    }
}
cli/src/tunnels/paths.rs (new file, 154 lines)
@@ -0,0 +1,154 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    fs::{read_dir, read_to_string, remove_dir_all, write},
    path::PathBuf,
};

use serde::{Deserialize, Serialize};

use crate::{
    options::{self, Quality},
    state::LauncherPaths,
    util::{
        errors::{wrap, AnyError, WrappedError},
        machine,
    },
};

pub const SERVER_FOLDER_NAME: &str = "server";

pub struct ServerPaths {
    // Directory into which the server is downloaded
    pub server_dir: PathBuf,
    // Executable path, within the server_dir
    pub executable: PathBuf,
    // File where logs for the server should be written.
    pub logfile: PathBuf,
    // File where the process ID for the server should be written.
    pub pidfile: PathBuf,
}

impl ServerPaths {
    // Queries the system to determine the process ID of the running server.
    // Returns the process ID, if the server is running.
    pub fn get_running_pid(&self) -> Option<u32> {
        if let Some(pid) = self.read_pid() {
            return match machine::process_at_path_exists(pid, &self.executable) {
                true => Some(pid),
                false => None,
            };
        }

        if let Some(pid) = machine::find_running_process(&self.executable) {
            // attempt to backfill process ID:
            self.write_pid(pid).ok();
            return Some(pid);
        }

        None
    }

    /// Delete the server directory
    pub fn delete(&self) -> Result<(), WrappedError> {
        remove_dir_all(&self.server_dir).map_err(|e| {
            wrap(
                e,
                format!("error deleting server dir {}", self.server_dir.display()),
            )
        })
    }

    // VS Code Server pid
    pub fn write_pid(&self, pid: u32) -> Result<(), WrappedError> {
        write(&self.pidfile, format!("{}", pid)).map_err(|e| {
            wrap(
                e,
                format!("error writing process id into {}", self.pidfile.display()),
            )
        })
    }

    fn read_pid(&self) -> Option<u32> {
        read_to_string(&self.pidfile)
            .ok()
            .and_then(|s| s.parse::<u32>().ok())
    }
}

#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct InstalledServer {
    pub quality: options::Quality,
    pub commit: String,
    pub headless: bool,
}

impl InstalledServer {
    /// Gets path information about where a specific server should be stored.
    pub fn server_paths(&self, p: &LauncherPaths) -> ServerPaths {
        let server_dir = self.get_install_folder(p);
        ServerPaths {
            executable: server_dir
                .join(SERVER_FOLDER_NAME)
                .join("bin")
                .join(self.quality.server_entrypoint()),
            logfile: server_dir.join("log.txt"),
            pidfile: server_dir.join("pid.txt"),
            server_dir,
        }
    }

    fn get_install_folder(&self, p: &LauncherPaths) -> PathBuf {
        p.server_cache.path().join(if !self.headless {
            format!("{}-web", get_server_folder_name(self.quality, &self.commit))
        } else {
            get_server_folder_name(self.quality, &self.commit)
        })
    }
}

/// Prunes servers not currently running, and returns the deleted servers.
pub fn prune_stopped_servers(launcher_paths: &LauncherPaths) -> Result<Vec<ServerPaths>, AnyError> {
    get_all_servers(launcher_paths)
        .into_iter()
        .map(|s| s.server_paths(launcher_paths))
        .filter(|s| s.get_running_pid().is_none())
        .map(|s| s.delete().map(|_| s))
        .collect::<Result<_, _>>()
        .map_err(AnyError::from)
}

// Gets a list of all servers which look like they might be running.
pub fn get_all_servers(lp: &LauncherPaths) -> Vec<InstalledServer> {
    let mut servers: Vec<InstalledServer> = vec![];
    if let Ok(children) = read_dir(lp.server_cache.path()) {
        for child in children.flatten() {
            let fname = child.file_name();
            let fname = fname.to_string_lossy();
            let (quality, commit) = match fname.split_once('-') {
                Some(r) => r,
                None => continue,
            };

            let quality = match options::Quality::try_from(quality) {
                Ok(q) => q,
                Err(_) => continue,
            };

            servers.push(InstalledServer {
                quality,
                commit: commit.to_string(),
                headless: true,
            });
        }
    }

    servers
}

pub fn get_server_folder_name(quality: Quality, commit: &str) -> String {
    format!("{}-{}", quality, commit)
}
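`get_all_servers` recovers `(quality, commit)` by splitting a cache folder name on its first `-`, the inverse of `get_server_folder_name`; this works because quality names contain no dash. A quick standalone check of that roundtrip (the sample values are made up):

```rust
fn main() {
    let quality = "stable";
    let commit = "abc123def456";

    // get_server_folder_name: "{quality}-{commit}"
    let folder = format!("{}-{}", quality, commit);

    // get_all_servers: split on the first '-' only
    let (q, c) = folder.split_once('-').unwrap();
    assert_eq!((q, c), (quality, commit));
}
```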
cli/src/tunnels/port_forwarder.rs (new file, 131 lines)
@@ -0,0 +1,131 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::collections::HashSet;

use tokio::sync::{mpsc, oneshot};

use crate::{
    constants::CONTROL_PORT,
    util::errors::{AnyError, CannotForwardControlPort, ServerHasClosed},
};

use super::dev_tunnels::ActiveTunnel;

pub enum PortForwardingRec {
    Forward(u16, oneshot::Sender<Result<String, AnyError>>),
    Unforward(u16, oneshot::Sender<Result<(), AnyError>>),
}

/// Provides a port forwarding service for connected clients. Clients can make
/// requests on it, which are (and *must be*) processed by calling the `.process()`
/// method on the forwarder.
pub struct PortForwardingProcessor {
    tx: mpsc::Sender<PortForwardingRec>,
    rx: mpsc::Receiver<PortForwardingRec>,
    forwarded: HashSet<u16>,
}

impl PortForwardingProcessor {
    pub fn new() -> Self {
        let (tx, rx) = mpsc::channel(8);
        Self {
            tx,
            rx,
            forwarded: HashSet::new(),
        }
    }

    /// Gets a handle that can be passed off to consumers of port forwarding.
    pub fn handle(&self) -> PortForwarding {
        PortForwarding {
            tx: self.tx.clone(),
        }
    }

    /// Receives port forwarding requests. Consumers MUST call `process()`
    /// with the received requests.
    pub async fn recv(&mut self) -> Option<PortForwardingRec> {
        self.rx.recv().await
    }

    /// Processes the incoming forwarding request.
    pub async fn process(&mut self, req: PortForwardingRec, tunnel: &mut ActiveTunnel) {
        match req {
            PortForwardingRec::Forward(port, tx) => {
                tx.send(self.process_forward(port, tunnel).await).ok();
            }
            PortForwardingRec::Unforward(port, tx) => {
                tx.send(self.process_unforward(port, tunnel).await).ok();
            }
        }
    }

    async fn process_unforward(
        &mut self,
        port: u16,
        tunnel: &mut ActiveTunnel,
    ) -> Result<(), AnyError> {
        if port == CONTROL_PORT {
            return Err(CannotForwardControlPort().into());
        }

        tunnel.remove_port(port).await?;
        self.forwarded.remove(&port);
        Ok(())
    }

    async fn process_forward(
        &mut self,
        port: u16,
        tunnel: &mut ActiveTunnel,
    ) -> Result<String, AnyError> {
        if port == CONTROL_PORT {
            return Err(CannotForwardControlPort().into());
        }

        if !self.forwarded.contains(&port) {
            tunnel.add_port_tcp(port).await?;
            self.forwarded.insert(port);
        }

        tunnel.get_port_uri(port).await
    }
}

#[derive(Clone)]
pub struct PortForwarding {
    tx: mpsc::Sender<PortForwardingRec>,
}

impl PortForwarding {
    pub async fn forward(&self, port: u16) -> Result<String, AnyError> {
        let (tx, rx) = oneshot::channel();
        let req = PortForwardingRec::Forward(port, tx);

        if self.tx.send(req).await.is_err() {
            return Err(ServerHasClosed().into());
        }

        match rx.await {
            Ok(r) => r,
            Err(_) => Err(ServerHasClosed().into()),
        }
    }

    pub async fn unforward(&self, port: u16) -> Result<(), AnyError> {
        let (tx, rx) = oneshot::channel();
        let req = PortForwardingRec::Unforward(port, tx);

        if self.tx.send(req).await.is_err() {
            return Err(ServerHasClosed().into());
        }

        match rx.await {
            Ok(r) => r,
            Err(_) => Err(ServerHasClosed().into()),
        }
    }
}
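`PortForwardingProcessor` and its cloneable `PortForwarding` handle follow a common actor-style pattern: handles send a request plus a `oneshot` reply channel over an `mpsc` queue, and a single owner task answers each one. A generic, runnable sketch of the same shape (the placeholder URI stands in for `tunnel.get_port_uri`):

```rust
use tokio::sync::{mpsc, oneshot};

#[tokio::main]
async fn main() {
    // Each request carries the port and a oneshot Sender for the reply.
    let (tx, mut rx) = mpsc::channel::<(u16, oneshot::Sender<String>)>(8);

    // Processor side: the analogue of the `recv()` + `process()` loop.
    tokio::spawn(async move {
        while let Some((port, reply)) = rx.recv().await {
            reply
                .send(format!("https://tunnel.example.invalid/{}", port))
                .ok();
        }
    });

    // Handle side: the analogue of `PortForwarding::forward`.
    let (reply_tx, reply_rx) = oneshot::channel();
    tx.send((8080, reply_tx)).await.expect("processor has closed");
    println!("forwarded to {}", reply_rx.await.expect("reply dropped"));
}
```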
cli/src/tunnels/protocol.rs (new file, 249 lines)
@@ -0,0 +1,249 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use std::collections::HashMap;

use crate::{
    constants::{PROTOCOL_VERSION, VSCODE_CLI_VERSION},
    options::Quality,
    update_service::Platform,
};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Debug)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
#[allow(non_camel_case_types)]
pub enum ClientRequestMethod<'a> {
    servermsg(RefServerMessageParams<'a>),
    serverlog(ServerLog<'a>),
    makehttpreq(HttpRequestParams<'a>),
    version(VersionResponse),
}

#[derive(Deserialize, Debug)]
pub struct HttpBodyParams {
    #[serde(with = "serde_bytes")]
    pub segment: Vec<u8>,
    pub complete: bool,
    pub req_id: u32,
}

#[derive(Serialize, Debug)]
pub struct HttpRequestParams<'a> {
    pub url: &'a str,
    pub method: &'static str,
    pub req_id: u32,
}

#[derive(Deserialize, Debug)]
pub struct HttpHeadersParams {
    pub status_code: u16,
    pub headers: Vec<(String, String)>,
    pub req_id: u32,
}

#[derive(Deserialize, Debug)]
pub struct ForwardParams {
    pub port: u16,
}

#[derive(Deserialize, Debug)]
pub struct UnforwardParams {
    pub port: u16,
}

#[derive(Serialize)]
pub struct ForwardResult {
    pub uri: String,
}

#[derive(Deserialize, Debug)]
pub struct ServeParams {
    pub socket_id: u16,
    pub commit_id: Option<String>,
    pub quality: Quality,
    pub extensions: Vec<String>,
    /// Optional preferred connection token.
    #[serde(default)]
    pub connection_token: Option<String>,
    #[serde(default)]
    pub use_local_download: bool,
    /// If true, the client and server should gzip servermsg's sent in either direction.
    #[serde(default)]
    pub compress: bool,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct EmptyObject {}

#[derive(Serialize, Deserialize, Debug)]
pub struct UpdateParams {
    pub do_update: bool,
}

#[derive(Deserialize, Debug)]
pub struct ServerMessageParams {
    pub i: u16,
    #[serde(with = "serde_bytes")]
    pub body: Vec<u8>,
}

#[derive(Serialize, Debug)]
pub struct RefServerMessageParams<'a> {
    pub i: u16,
    #[serde(with = "serde_bytes")]
    pub body: &'a [u8],
}

#[derive(Serialize)]
pub struct UpdateResult {
    pub up_to_date: bool,
    pub did_update: bool,
}

#[derive(Serialize, Debug)]
pub struct ToClientRequest<'a> {
    pub id: Option<u32>,
    #[serde(flatten)]
    pub params: ClientRequestMethod<'a>,
}
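The `serde(tag = "method", content = "params")` attribute on `ClientRequestMethod`, combined with `serde(flatten)` on `ToClientRequest`, yields a JSON-RPC-like envelope on the wire. A trimmed-down, runnable illustration (only the `serverlog` variant is reproduced; requires serde and serde_json):

```rust
use serde::Serialize;

#[derive(Serialize)]
#[serde(tag = "method", content = "params", rename_all = "camelCase")]
#[allow(non_camel_case_types)]
enum ClientRequestMethod<'a> {
    serverlog(ServerLog<'a>),
}

#[derive(Serialize)]
struct ServerLog<'a> {
    line: &'a str,
    level: u8,
}

#[derive(Serialize)]
struct ToClientRequest<'a> {
    id: Option<u32>,
    #[serde(flatten)]
    params: ClientRequestMethod<'a>,
}

fn main() {
    let req = ToClientRequest {
        id: Some(1),
        params: ClientRequestMethod::serverlog(ServerLog { line: "ready", level: 0 }),
    };
    // Prints: {"id":1,"method":"serverlog","params":{"line":"ready","level":0}}
    println!("{}", serde_json::to_string(&req).unwrap());
}
```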
#[derive(Debug, Default, Serialize)]
pub struct ServerLog<'a> {
    pub line: &'a str,
    pub level: u8,
}

#[derive(Serialize)]
pub struct GetHostnameResponse {
    pub value: String,
}

#[derive(Serialize)]
pub struct GetEnvResponse {
    pub env: HashMap<String, String>,
    pub os_platform: &'static str,
    pub os_release: String,
}

#[derive(Deserialize)]
pub struct FsStatRequest {
    pub path: String,
}

#[derive(Serialize, Default)]
pub struct FsStatResponse {
    pub exists: bool,
    pub size: Option<u64>,
    #[serde(rename = "type")]
    pub kind: Option<&'static str>,
}

#[derive(Deserialize, Debug)]
pub struct CallServerHttpParams {
    pub path: String,
    pub method: String,
    pub headers: HashMap<String, String>,
    pub body: Option<Vec<u8>>,
}

#[derive(Serialize)]
pub struct CallServerHttpResult {
    pub status: u16,
    #[serde(with = "serde_bytes")]
    pub body: Vec<u8>,
    pub headers: HashMap<String, String>,
}

#[derive(Serialize, Debug)]
pub struct VersionResponse {
    pub version: &'static str,
    pub protocol_version: u32,
}

impl Default for VersionResponse {
    fn default() -> Self {
        Self {
            version: VSCODE_CLI_VERSION.unwrap_or("dev"),
            protocol_version: PROTOCOL_VERSION,
        }
    }
}

#[derive(Deserialize)]
pub struct SpawnParams {
    pub command: String,
    pub args: Vec<String>,
    #[serde(default)]
    pub cwd: Option<String>,
    #[serde(default)]
    pub env: HashMap<String, String>,
}

#[derive(Deserialize)]
pub struct AcquireCliParams {
    pub platform: Platform,
    pub quality: Quality,
    pub commit_id: Option<String>,
    #[serde(flatten)]
    pub spawn: SpawnParams,
}

#[derive(Serialize)]
pub struct SpawnResult {
    pub message: String,
    pub exit_code: i32,
}

pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";

#[derive(Serialize, Deserialize)]
pub struct ChallengeIssueResponse {
    pub challenge: String,
}

#[derive(Deserialize, Serialize)]
pub struct ChallengeVerifyParams {
    pub response: String,
}

pub mod singleton {
    use crate::log;
    use serde::{Deserialize, Serialize};

    pub const METHOD_RESTART: &str = "restart";
    pub const METHOD_SHUTDOWN: &str = "shutdown";
    pub const METHOD_STATUS: &str = "status";
    pub const METHOD_LOG: &str = "log";
    pub const METHOD_LOG_REPLY_DONE: &str = "log_done";

    #[derive(Serialize)]
    pub struct LogMessage<'a> {
        pub level: Option<log::Level>,
        pub prefix: &'a str,
        pub message: &'a str,
    }

    #[derive(Deserialize)]
    pub struct LogMessageOwned {
        pub level: Option<log::Level>,
        pub prefix: String,
        pub message: String,
    }

    #[derive(Serialize, Deserialize)]
    pub struct Status {
        pub tunnel: TunnelState,
    }

    #[derive(Deserialize, Serialize, Debug)]
    pub struct LogReplayFinished {}

    #[derive(Deserialize, Serialize, Debug)]
    pub enum TunnelState {
        Disconnected,
        Connected { name: String },
    }
}
cli/src/tunnels/server_bridge.rs (new file, 62 lines)
@@ -0,0 +1,62 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/
use super::socket_signal::{ClientMessageDecoder, ServerMessageSink};
use crate::{
    async_pipe::{get_socket_rw_stream, socket_stream_split, AsyncPipeWriteHalf},
    util::errors::AnyError,
};
use std::path::Path;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

pub struct ServerBridge {
    write: AsyncPipeWriteHalf,
    decoder: ClientMessageDecoder,
}

const BUFFER_SIZE: usize = 65536;

impl ServerBridge {
    pub async fn new(
        path: &Path,
        mut target: ServerMessageSink,
        decoder: ClientMessageDecoder,
    ) -> Result<Self, AnyError> {
        let stream = get_socket_rw_stream(path).await?;
        let (mut read, write) = socket_stream_split(stream);

        tokio::spawn(async move {
            let mut read_buf = vec![0; BUFFER_SIZE];
            loop {
                match read.read(&mut read_buf).await {
                    Err(_) => return,
                    Ok(0) => {
                        return; // EOF
                    }
                    Ok(s) => {
                        let send = target.server_message(&read_buf[..s]).await;
                        if send.is_err() {
                            return;
                        }
                    }
                }
            }
        });

        Ok(ServerBridge { write, decoder })
    }

    pub async fn write(&mut self, b: Vec<u8>) -> std::io::Result<()> {
        let dec = self.decoder.decode(&b)?;
        if !dec.is_empty() {
            self.write.write_all(dec).await?;
        }
        Ok(())
    }

    pub async fn close(mut self) -> std::io::Result<()> {
        self.write.shutdown().await?;
        Ok(())
    }
}
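`ServerBridge::new` spawns a read pump: a task that copies fixed-size chunks from the socket into the message sink until EOF or an error, leaving the write half usable independently. The same shape, runnable against an in-memory `tokio::io::duplex` stream as a stand-in for the real socket:

```rust
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (mut client, mut server) = duplex(64);
    let (tx, mut rx) = mpsc::channel::<Vec<u8>>(4);

    // The pump: mirrors the loop in `ServerBridge::new`.
    tokio::spawn(async move {
        let mut buf = vec![0u8; 64];
        loop {
            match server.read(&mut buf).await {
                Err(_) | Ok(0) => return, // error or EOF ends the pump
                Ok(n) => {
                    // analogous to `target.server_message(..)` failing
                    if tx.send(buf[..n].to_vec()).await.is_err() {
                        return;
                    }
                }
            }
        }
    });

    client.write_all(b"hello").await.unwrap();
    drop(client); // EOF
    while let Some(chunk) = rx.recv().await {
        println!("pumped {} bytes", chunk.len());
    }
}
```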
cli/src/tunnels/server_multiplexer.rs (new file, 145 lines)
@@ -0,0 +1,145 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::sync::Arc;

use futures::future::join_all;

use crate::log;

use super::server_bridge::ServerBridge;

type Inner = Arc<std::sync::Mutex<Option<Vec<ServerBridgeRec>>>>;

struct ServerBridgeRec {
    id: u16,
    // bridge is removed when there's a write loop currently active
    bridge: Option<ServerBridge>,
    write_queue: Vec<Vec<u8>>,
}

/// The ServerMultiplexer manages multiple server bridges and allows writing
/// to them in a thread-safe way. It is Clone, Send, and Sync.
#[derive(Clone)]
pub struct ServerMultiplexer {
    inner: Inner,
}

impl ServerMultiplexer {
    pub fn new() -> Self {
        Self {
            inner: Arc::new(std::sync::Mutex::new(Some(Vec::new()))),
        }
    }

    /// Adds a new bridge to the multiplexer.
    pub fn register(&self, id: u16, bridge: ServerBridge) {
        let bridge_rec = ServerBridgeRec {
            id,
            bridge: Some(bridge),
            write_queue: vec![],
        };

        let mut lock = self.inner.lock().unwrap();
        match &mut *lock {
            Some(server_bridges) => (*server_bridges).push(bridge_rec),
            None => *lock = Some(vec![bridge_rec]),
        }
    }

    /// Removes a server bridge by ID.
    pub fn remove(&self, id: u16) {
        let mut lock = self.inner.lock().unwrap();
        if let Some(bridges) = &mut *lock {
            bridges.retain(|sb| sb.id != id);
        }
    }

    /// Handle an incoming server message. This is synchronous and uses a 'write loop'
    /// to ensure message order is preserved exactly, which is necessary for compression.
    /// Returns false if there was no server with the given bridge_id.
    pub fn write_message(&self, log: &log::Logger, bridge_id: u16, message: Vec<u8>) -> bool {
        let mut lock = self.inner.lock().unwrap();

        let bridges = match &mut *lock {
            Some(sb) => sb,
            None => return false,
        };

        let record = match bridges.iter_mut().find(|b| b.id == bridge_id) {
            Some(sb) => sb,
            None => return false,
        };

        record.write_queue.push(message);
        if let Some(bridge) = record.bridge.take() {
            let bridges_lock = self.inner.clone();
            let log = log.clone();
            tokio::spawn(write_loop(log, record.id, bridge, bridges_lock));
        }

        true
    }

    /// Disposes all running server bridges.
    pub async fn dispose(&self) {
        let bridges = {
            let mut lock = self.inner.lock().unwrap();
            lock.take()
        };

        let bridges = match bridges {
            Some(b) => b,
            None => return,
        };

        join_all(
            bridges
                .into_iter()
                .filter_map(|b| b.bridge)
                .map(|b| b.close()),
        )
        .await;
    }
}

/// Write loop started by `write_message`. It takes the ServerBridge, and
/// runs until there are no more items in the 'write queue'. At that point, if the
/// record still exists in the bridges_lock (i.e. we haven't shut down), it'll
/// return the ServerBridge so that the next `write_message` call starts
/// the loop again. Otherwise, it'll close the bridge.
async fn write_loop(log: log::Logger, id: u16, mut bridge: ServerBridge, bridges_lock: Inner) {
    let mut items_vec = vec![];
    loop {
        {
            let mut lock = bridges_lock.lock().unwrap();
            let server_bridges = match &mut *lock {
                Some(sb) => sb,
                None => break,
            };

            let bridge_rec = match server_bridges.iter_mut().find(|b| id == b.id) {
                Some(b) => b,
                None => break,
            };

            if bridge_rec.write_queue.is_empty() {
                bridge_rec.bridge = Some(bridge);
                return;
            }

            std::mem::swap(&mut bridge_rec.write_queue, &mut items_vec);
        }

        for item in items_vec.drain(..) {
            if let Err(e) = bridge.write(item).await {
                warning!(log, "Error writing to server: {:?}", e);
                break;
            }
        }
    }

    bridge.close().await.ok(); // reached via `break` above: our record is gone, so close the bridge
}
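The queue-and-handoff design above guarantees ordering without holding a lock across an await: writers enqueue synchronously, and the bridge is `take`n by at most one drainer at a time, which hands it back only once the queue is empty. A compact sketch of the same invariant in synchronous form (the `Sink` type is illustrative; the real code spawns `write_loop` where this calls `drain`):

```rust
use std::mem;
use std::sync::Mutex;

#[derive(Default)]
struct Sink {
    // (queued messages, whether a drainer currently owns the output)
    state: Mutex<(Vec<String>, bool)>,
}

impl Sink {
    fn write(&self, msg: String) {
        let start_drain = {
            let mut s = self.state.lock().unwrap();
            s.0.push(msg);
            if s.1 {
                false // a drainer is already active; it will pick this up
            } else {
                s.1 = true; // claim the drainer role, like `record.bridge.take()`
                true
            }
        };
        if start_drain {
            self.drain(); // the real code spawns `write_loop` here instead
        }
    }

    fn drain(&self) {
        loop {
            let batch = {
                let mut s = self.state.lock().unwrap();
                if s.0.is_empty() {
                    s.1 = false; // hand the role back, like restoring `bridge`
                    return;
                }
                mem::take(&mut s.0)
            };
            for msg in batch {
                println!("{}", msg); // stand-in for `bridge.write(...)`
            }
        }
    }
}

fn main() {
    let sink = Sink::default();
    for i in 1..=5 {
        sink.write(format!("message {}", i)); // emitted strictly in enqueue order
    }
}
```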
cli/src/tunnels/service.rs (new file, 92 lines)
@@ -0,0 +1,92 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::path::{Path, PathBuf};

use async_trait::async_trait;

use crate::log;
use crate::state::LauncherPaths;
use crate::util::errors::{wrap, AnyError};
use crate::util::io::{tailf, TailEvent};

pub const SERVICE_LOG_FILE_NAME: &str = "tunnel-service.log";

#[async_trait]
pub trait ServiceContainer: Send {
    async fn run_service(
        &mut self,
        log: log::Logger,
        launcher_paths: LauncherPaths,
    ) -> Result<(), AnyError>;
}

#[async_trait]
pub trait ServiceManager {
    /// Registers the current executable as a service to run with the given set
    /// of arguments.
    async fn register(&self, exe: PathBuf, args: &[&str]) -> Result<(), AnyError>;

    /// Runs the service using the given handle. The executable *must not* take
    /// any action which may fail prior to calling this to ensure service
    /// states may update.
    async fn run(
        self,
        launcher_paths: LauncherPaths,
        handle: impl 'static + ServiceContainer,
    ) -> Result<(), AnyError>;

    /// Show logs from the running service to standard out.
    async fn show_logs(&self) -> Result<(), AnyError>;

    /// Unregisters the current executable as a service.
    async fn unregister(&self) -> Result<(), AnyError>;
}

#[cfg(target_os = "windows")]
pub type ServiceManagerImpl = super::service_windows::WindowsService;

#[cfg(target_os = "linux")]
pub type ServiceManagerImpl = super::service_linux::SystemdService;

#[cfg(target_os = "macos")]
pub type ServiceManagerImpl = super::service_macos::LaunchdService;

#[allow(unreachable_code)]
#[allow(unused_variables)]
pub fn create_service_manager(log: log::Logger, paths: &LauncherPaths) -> ServiceManagerImpl {
    #[cfg(target_os = "macos")]
    {
        super::service_macos::LaunchdService::new(log, paths)
    }
    #[cfg(target_os = "windows")]
    {
        super::service_windows::WindowsService::new(log, paths)
    }
    #[cfg(target_os = "linux")]
    {
        super::service_linux::SystemdService::new(log, paths.clone())
    }
}

#[allow(dead_code)] // unused on Linux
pub(crate) async fn tail_log_file(log_file: &Path) -> Result<(), AnyError> {
    if !log_file.exists() {
        println!("The tunnel service has not started yet.");
        return Ok(());
    }

    let file = std::fs::File::open(log_file).map_err(|e| wrap(e, "error opening log file"))?;
    let mut rx = tailf(file, 20);
    while let Some(line) = rx.recv().await {
        match line {
            TailEvent::Line(l) => print!("{}", l),
            TailEvent::Reset => println!("== Tunnel service restarted =="),
            TailEvent::Err(e) => return Err(wrap(e, "error reading log file").into()),
        }
    }

    Ok(())
}
cli/src/tunnels/service_linux.rs (new file, 240 lines)
@@ -0,0 +1,240 @@
|
||||
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    fs::File,
    io::{self, Write},
    path::PathBuf,
    process::Command,
};

use async_trait::async_trait;
use zbus::{dbus_proxy, zvariant, Connection};

use crate::{
    constants::{APPLICATION_NAME, PRODUCT_NAME_LONG},
    log,
    state::LauncherPaths,
    util::errors::{wrap, AnyError},
};

use super::ServiceManager;

pub struct SystemdService {
    log: log::Logger,
    service_file: PathBuf,
}

impl SystemdService {
    pub fn new(log: log::Logger, paths: LauncherPaths) -> Self {
        Self {
            log,
            service_file: paths.root().join(SystemdService::service_name_string()),
        }
    }
}

impl SystemdService {
    async fn connect() -> Result<Connection, AnyError> {
        let connection = Connection::session()
            .await
            .map_err(|e| wrap(e, "error creating dbus session"))?;
        Ok(connection)
    }

    async fn proxy(connection: &Connection) -> Result<SystemdManagerDbusProxy<'_>, AnyError> {
        let proxy = SystemdManagerDbusProxy::new(connection)
            .await
            .map_err(|e| {
                wrap(
                    e,
                    "error connecting to systemd, you may need to re-run with sudo:",
                )
            })?;

        Ok(proxy)
    }

    fn service_path_string(&self) -> String {
        self.service_file.as_os_str().to_string_lossy().to_string()
    }

    fn service_name_string() -> String {
        format!("{}-tunnel.service", APPLICATION_NAME)
    }
}

#[async_trait]
impl ServiceManager for SystemdService {
    async fn register(
        &self,
        exe: std::path::PathBuf,
        args: &[&str],
    ) -> Result<(), crate::util::errors::AnyError> {
        let connection = SystemdService::connect().await?;
        let proxy = SystemdService::proxy(&connection).await?;

        write_systemd_service_file(&self.service_file, exe, args)
            .map_err(|e| wrap(e, "error creating service file"))?;

        proxy
            .link_unit_files(
                vec![self.service_path_string()],
                /* 'runtime only' = */ false,
                /* replace existing = */ true,
            )
            .await
            .map_err(|e| wrap(e, "error registering service"))?;

        info!(self.log, "Successfully registered service...");

        // note: enablement is implicit in recent systemd versions, but required for older systems
        // https://github.com/microsoft/vscode/issues/167489#issuecomment-1331222826
        proxy
            .enable_unit_files(
                vec![SystemdService::service_name_string()],
                /* 'runtime only' = */ false,
                /* replace existing = */ true,
            )
            .await
            .map_err(|e| wrap(e, "error enabling unit files for service"))?;

        info!(self.log, "Successfully enabled unit files...");

        proxy
            .start_unit(SystemdService::service_name_string(), "replace".to_string())
            .await
            .map_err(|e| wrap(e, "error starting service"))?;

        info!(self.log, "Tunnel service successfully started");

        Ok(())
    }

    async fn run(
        self,
        launcher_paths: crate::state::LauncherPaths,
        mut handle: impl 'static + super::ServiceContainer,
    ) -> Result<(), crate::util::errors::AnyError> {
        handle.run_service(self.log, launcher_paths).await
    }

    async fn show_logs(&self) -> Result<(), AnyError> {
        // show the systemctl status header...
        Command::new("systemctl")
            .args([
                "--user",
                "status",
                "-n",
                "0",
                &SystemdService::service_name_string(),
            ])
            .status()
            .map(|s| s.code().unwrap_or(1))
            .map_err(|e| wrap(e, "error running systemctl"))?;

        // then follow log files
        Command::new("journalctl")
            .args(["--user", "-f", "-u", &SystemdService::service_name_string()])
            .status()
            .map(|s| s.code().unwrap_or(1))
            .map_err(|e| wrap(e, "error running journalctl"))?;
        Ok(())
    }

    async fn unregister(&self) -> Result<(), crate::util::errors::AnyError> {
        let connection = SystemdService::connect().await?;
        let proxy = SystemdService::proxy(&connection).await?;

        proxy
            .stop_unit(SystemdService::service_name_string(), "replace".to_string())
            .await
            .map_err(|e| wrap(e, "error unregistering service"))?;

        info!(self.log, "Successfully stopped service...");

        proxy
            .disable_unit_files(
                vec![SystemdService::service_name_string()],
                /* 'runtime only' = */ false,
            )
            .await
            .map_err(|e| wrap(e, "error unregistering service"))?;

        info!(self.log, "Tunnel service uninstalled");

        Ok(())
    }
}

fn write_systemd_service_file(
    path: &PathBuf,
    exe: std::path::PathBuf,
    args: &[&str],
) -> io::Result<()> {
    let mut f = File::create(path)?;
    write!(
        &mut f,
        "[Unit]\n\
        Description={} Tunnel\n\
        After=network.target\n\
        StartLimitIntervalSec=0\n\
        \n\
        [Service]\n\
        Type=simple\n\
        Restart=always\n\
        RestartSec=10\n\
        ExecStart={} \"{}\"\n\
        \n\
        [Install]\n\
        WantedBy=default.target\n\
        ",
        PRODUCT_NAME_LONG,
        exe.into_os_string().to_string_lossy(),
        args.join("\" \"")
    )?;
    Ok(())
}
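
// For reference, a rough sketch of the unit file the function above renders.
// All concrete values here are illustrative (product name, binary path, and
// arguments depend on the build and on what the caller passes in):
//
//     [Unit]
//     Description=Visual Studio Code Tunnel
//     After=network.target
//     StartLimitIntervalSec=0
//
//     [Service]
//     Type=simple
//     Restart=always
//     RestartSec=10
//     ExecStart=/usr/bin/code "tunnel"
//
//     [Install]
//     WantedBy=default.target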

/// Minimal implementation of systemd types for the services we need. The full
/// definition can be found on any systemd machine with the command:
///
/// gdbus introspect --system --dest org.freedesktop.systemd1 --object-path /org/freedesktop/systemd1
///
/// See docs here: https://www.freedesktop.org/software/systemd/man/org.freedesktop.systemd1.html
#[dbus_proxy(
    interface = "org.freedesktop.systemd1.Manager",
    gen_blocking = false,
    default_service = "org.freedesktop.systemd1",
    default_path = "/org/freedesktop/systemd1"
)]
trait SystemdManagerDbus {
    #[dbus_proxy(name = "EnableUnitFiles")]
    fn enable_unit_files(
        &self,
        files: Vec<String>,
        runtime: bool,
        force: bool,
    ) -> zbus::Result<(bool, Vec<(String, String, String)>)>;

    fn link_unit_files(
        &self,
        files: Vec<String>,
        runtime: bool,
        force: bool,
    ) -> zbus::Result<Vec<(String, String, String)>>;

    fn disable_unit_files(
        &self,
        files: Vec<String>,
        runtime: bool,
    ) -> zbus::Result<Vec<(String, String, String)>>;

    #[dbus_proxy(name = "StartUnit")]
    fn start_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;

    #[dbus_proxy(name = "StopUnit")]
    fn stop_unit(&self, name: String, mode: String) -> zbus::Result<zvariant::OwnedObjectPath>;
}
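
// Illustrative only: the #[dbus_proxy] macro above generates an async
// `SystemdManagerDbusProxy` type, which is consumed earlier in this file
// roughly as:
//
//     let connection = Connection::session().await?;
//     let proxy = SystemdManagerDbusProxy::new(&connection).await?;
//     proxy.start_unit("code-tunnel.service".to_string(), "replace".to_string()).await?;
//
// ("code-tunnel.service" is a hypothetical unit name; the real one comes from
// SystemdService::service_name_string().)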
cli/src/tunnels/service_macos.rs (new file, 163 lines)
@@ -0,0 +1,163 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    fs::{remove_file, File},
    io::{self, Write},
    path::{Path, PathBuf},
};

use async_trait::async_trait;

use crate::{
    constants::APPLICATION_NAME,
    log,
    state::LauncherPaths,
    util::{
        command::capture_command_and_check_status,
        errors::{wrap, AnyError, CodeError, MissingHomeDirectory},
    },
};

use super::{service::tail_log_file, ServiceManager};

pub struct LaunchdService {
    log: log::Logger,
    log_file: PathBuf,
}

impl LaunchdService {
    pub fn new(log: log::Logger, paths: &LauncherPaths) -> Self {
        Self {
            log,
            log_file: paths.service_log_file(),
        }
    }
}

#[async_trait]
impl ServiceManager for LaunchdService {
    async fn register(
        &self,
        exe: std::path::PathBuf,
        args: &[&str],
    ) -> Result<(), crate::util::errors::AnyError> {
        let service_file = get_service_file_path()?;
        write_service_file(&service_file, &self.log_file, exe, args)
            .map_err(|e| wrap(e, "error creating service file"))?;

        info!(self.log, "Successfully registered service...");

        capture_command_and_check_status(
            "launchctl",
            &["load", service_file.as_os_str().to_string_lossy().as_ref()],
        )
        .await?;

        capture_command_and_check_status("launchctl", &["start", &get_service_label()]).await?;

        info!(self.log, "Tunnel service successfully started");

        Ok(())
    }

    async fn show_logs(&self) -> Result<(), AnyError> {
        tail_log_file(&self.log_file).await
    }

    async fn run(
        self,
        launcher_paths: crate::state::LauncherPaths,
        mut handle: impl 'static + super::ServiceContainer,
    ) -> Result<(), crate::util::errors::AnyError> {
        handle.run_service(self.log, launcher_paths).await
    }

    async fn unregister(&self) -> Result<(), crate::util::errors::AnyError> {
        let service_file = get_service_file_path()?;

        match capture_command_and_check_status("launchctl", &["stop", &get_service_label()]).await {
            Ok(_) => {}
            // status 3 == "no such process"
            Err(CodeError::CommandFailed { code, .. }) if code == 3 => {}
            Err(e) => return Err(wrap(e, "error stopping service").into()),
        };

        info!(self.log, "Successfully stopped service...");

        capture_command_and_check_status(
            "launchctl",
            &[
                "unload",
                service_file.as_os_str().to_string_lossy().as_ref(),
            ],
        )
        .await?;

        info!(self.log, "Tunnel service uninstalled");

        if let Ok(f) = get_service_file_path() {
            remove_file(f).ok();
        }

        Ok(())
    }
}

fn get_service_label() -> String {
    format!("com.visualstudio.{}.tunnel", APPLICATION_NAME)
}

fn get_service_file_path() -> Result<PathBuf, MissingHomeDirectory> {
    match dirs::home_dir() {
        Some(mut d) => {
            d.push(format!("{}.plist", get_service_label()));
            Ok(d)
        }
        None => Err(MissingHomeDirectory()),
    }
}

fn write_service_file(
    path: &PathBuf,
    log_file: &Path,
    exe: std::path::PathBuf,
    args: &[&str],
) -> io::Result<()> {
    let mut f = File::create(path)?;
    let log_file = log_file.as_os_str().to_string_lossy();
    // todo: we may be able to skip file logging and use the ASL instead
    // if/when we no longer need to support older macOS versions.
    write!(
        &mut f,
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
        <!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n\
        <plist version=\"1.0\">\n\
        <dict>\n\
        <key>Label</key>\n\
        <string>{}</string>\n\
        <key>LimitLoadToSessionType</key>\n\
        <string>Aqua</string>\n\
        <key>ProgramArguments</key>\n\
        <array>\n\
        <string>{}</string>\n\
        <string>{}</string>\n\
        </array>\n\
        <key>KeepAlive</key>\n\
        <true/>\n\
        <key>StandardErrorPath</key>\n\
        <string>{}</string>\n\
        <key>StandardOutPath</key>\n\
        <string>{}</string>\n\
        </dict>\n\
        </plist>",
        get_service_label(),
        exe.into_os_string().to_string_lossy(),
        args.join("</string><string>"),
        log_file,
        log_file
    )?;
    Ok(())
}
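
// For reference, a sketch of the plist rendered by the function above. The
// label, binary path, arguments, and log path are all illustrative:
//
//     <?xml version="1.0" encoding="UTF-8"?>
//     <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
//     <plist version="1.0">
//     <dict>
//         <key>Label</key>
//         <string>com.visualstudio.code.tunnel</string>
//         <key>LimitLoadToSessionType</key>
//         <string>Aqua</string>
//         <key>ProgramArguments</key>
//         <array>
//             <string>/usr/local/bin/code</string>
//             <string>tunnel</string>
//         </array>
//         <key>KeepAlive</key>
//         <true/>
//         <key>StandardErrorPath</key>
//         <string>/Users/me/.vscode-cli/tunnel-service.log</string>
//         <key>StandardOutPath</key>
//         <string>/Users/me/.vscode-cli/tunnel-service.log</string>
//     </dict>
//     </plist>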
cli/src/tunnels/service_windows.rs (new file, 118 lines)
@@ -0,0 +1,118 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use async_trait::async_trait;
use shell_escape::windows::escape as shell_escape;
use std::{
    path::PathBuf,
    process::{Command, Stdio},
};
use winreg::{enums::HKEY_CURRENT_USER, RegKey};

use crate::{
    constants::TUNNEL_ACTIVITY_NAME,
    log,
    state::LauncherPaths,
    tunnels::{protocol, singleton_client::do_single_rpc_call},
    util::errors::{wrap, wrapdbg, AnyError},
};

use super::service::{tail_log_file, ServiceContainer, ServiceManager as CliServiceManager};

pub struct WindowsService {
    log: log::Logger,
    tunnel_lock: PathBuf,
    log_file: PathBuf,
}

impl WindowsService {
    pub fn new(log: log::Logger, paths: &LauncherPaths) -> Self {
        Self {
            log,
            tunnel_lock: paths.tunnel_lockfile(),
            log_file: paths.service_log_file(),
        }
    }

    fn open_key() -> Result<RegKey, AnyError> {
        RegKey::predef(HKEY_CURRENT_USER)
            .create_subkey(r"Software\Microsoft\Windows\CurrentVersion\Run")
            .map_err(|e| wrap(e, "error opening run registry key").into())
            .map(|(key, _)| key)
    }
}

#[async_trait]
impl CliServiceManager for WindowsService {
    async fn register(&self, exe: std::path::PathBuf, args: &[&str]) -> Result<(), AnyError> {
        let key = WindowsService::open_key()?;

        let mut reg_str = String::new();
        let mut cmd = Command::new(&exe);
        reg_str.push_str(shell_escape(exe.to_string_lossy()).as_ref());

        let mut add_arg = |arg: &str| {
            reg_str.push(' ');
            reg_str.push_str(shell_escape((*arg).into()).as_ref());
            cmd.arg(arg);
        };

        for arg in args {
            add_arg(arg);
        }

        add_arg("--log-to-file");
        add_arg(self.log_file.to_string_lossy().as_ref());

        key.set_value(TUNNEL_ACTIVITY_NAME, &reg_str)
            .map_err(|e| AnyError::from(wrapdbg(e, "error setting registry key")))?;

        info!(self.log, "Successfully registered service...");

        cmd.stderr(Stdio::null());
        cmd.stdout(Stdio::null());
        cmd.stdin(Stdio::null());
        cmd.spawn()
            .map_err(|e| wrapdbg(e, "error starting service"))?;

        info!(self.log, "Tunnel service successfully started");
        Ok(())
    }
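
    // Note for readers: on Windows the "service" is an HKCU Run-key autostart
    // entry plus a detached child process, not a Windows Service proper. After
    // register(), the Run value holds a shell-escaped command line roughly
    // like the following (all paths hypothetical):
    //
    //     "C:\Users\me\AppData\Local\Programs\code.exe" tunnel service internal-run --log-to-file "C:\Users\me\.vscode-cli\tunnel-service.log"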

    async fn show_logs(&self) -> Result<(), AnyError> {
        tail_log_file(&self.log_file).await
    }

    async fn run(
        self,
        launcher_paths: LauncherPaths,
        mut handle: impl 'static + ServiceContainer,
    ) -> Result<(), AnyError> {
        handle.run_service(self.log, launcher_paths).await
    }

    async fn unregister(&self) -> Result<(), AnyError> {
        let key = WindowsService::open_key()?;
        key.delete_value(TUNNEL_ACTIVITY_NAME)
            .map_err(|e| AnyError::from(wrap(e, "error deleting registry key")))?;
        info!(self.log, "Tunnel service uninstalled");

        let r = do_single_rpc_call::<_, ()>(
            &self.tunnel_lock,
            self.log.clone(),
            protocol::singleton::METHOD_SHUTDOWN,
            protocol::EmptyObject {},
        )
        .await;

        if r.is_err() {
            warning!(self.log, "The tunnel service has been unregistered, but we couldn't find a running tunnel process. You may need to restart or log out and back in to fully stop the tunnel.");
        } else {
            info!(self.log, "Successfully shut down running tunnel.");
        }

        Ok(())
    }
}
cli/src/tunnels/shutdown_signal.rs (new file, 90 lines)
@@ -0,0 +1,90 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use futures::{stream::FuturesUnordered, StreamExt};
use std::{fmt, path::PathBuf};
use sysinfo::Pid;

use crate::util::{
    machine::{wait_until_exe_deleted, wait_until_process_exits},
    sync::{new_barrier, Barrier, Receivable},
};

/// Describes the signal to manually stop the server
#[derive(Copy, Clone)]
pub enum ShutdownSignal {
    CtrlC,
    ParentProcessKilled(Pid),
    ExeUninstalled,
    ServiceStopped,
    RpcShutdownRequested,
    RpcRestartRequested,
}

impl fmt::Display for ShutdownSignal {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ShutdownSignal::CtrlC => write!(f, "Ctrl-C received"),
            ShutdownSignal::ParentProcessKilled(p) => {
                write!(f, "Parent process {} no longer exists", p)
            }
            ShutdownSignal::ExeUninstalled => {
                write!(f, "Executable no longer exists")
            }
            ShutdownSignal::ServiceStopped => write!(f, "Service stopped"),
            ShutdownSignal::RpcShutdownRequested => write!(f, "RPC client requested shutdown"),
            ShutdownSignal::RpcRestartRequested => {
                write!(f, "RPC client requested a tunnel restart")
            }
        }
    }
}

pub enum ShutdownRequest {
    CtrlC,
    ParentProcessKilled(Pid),
    ExeUninstalled(PathBuf),
    Derived(Box<dyn Receivable<ShutdownSignal> + Send>),
}

impl ShutdownRequest {
    async fn wait(self) -> Option<ShutdownSignal> {
        match self {
            ShutdownRequest::CtrlC => {
                let ctrl_c = tokio::signal::ctrl_c();
                ctrl_c.await.ok();
                Some(ShutdownSignal::CtrlC)
            }
            ShutdownRequest::ParentProcessKilled(pid) => {
                wait_until_process_exits(pid, 2000).await;
                Some(ShutdownSignal::ParentProcessKilled(pid))
            }
            ShutdownRequest::ExeUninstalled(exe_path) => {
                wait_until_exe_deleted(&exe_path, 2000).await;
                Some(ShutdownSignal::ExeUninstalled)
            }
            ShutdownRequest::Derived(mut rx) => rx.recv_msg().await,
        }
    }

    /// Creates a receiver barrier that is signaled once any of the given
    /// requests resolve. Note: does not handle ServiceStopped.
    pub fn create_rx(
        signals: impl IntoIterator<Item = ShutdownRequest>,
    ) -> Barrier<ShutdownSignal> {
        let (barrier, opener) = new_barrier();
        let futures = signals
            .into_iter()
            .map(|s| s.wait())
            .collect::<FuturesUnordered<_>>();

        tokio::spawn(async move {
            if let Some(s) = futures.filter_map(futures::future::ready).next().await {
                opener.open(s);
            }
        });

        barrier
    }
}
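
// Illustrative usage (the PID value is hypothetical): several requests are
// combined into a single barrier that resolves with whichever signal fires
// first, and the barrier can be cloned and awaited by multiple consumers.
//
//     let shutdown = ShutdownRequest::create_rx([
//         ShutdownRequest::CtrlC,
//         ShutdownRequest::ParentProcessKilled(Pid::from(1234)),
//     ]);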
cli/src/tunnels/singleton_client.rs (new file, 190 lines)
@@ -0,0 +1,190 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    path::Path,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    thread,
};

use const_format::concatcp;
use tokio::sync::mpsc;

use crate::{
    async_pipe::{socket_stream_split, AsyncPipe},
    constants::IS_INTERACTIVE_CLI,
    json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer},
    log,
    rpc::RpcCaller,
    singleton::connect_as_client,
    tunnels::{code_server::print_listening, protocol::EmptyObject},
    util::{errors::CodeError, sync::Barrier},
};

use super::{
    protocol,
    shutdown_signal::{ShutdownRequest, ShutdownSignal},
};

pub struct SingletonClientArgs {
    pub log: log::Logger,
    pub stream: AsyncPipe,
    pub shutdown: Barrier<ShutdownSignal>,
}

struct SingletonServerContext {
    log: log::Logger,
    exit_entirely: Arc<AtomicBool>,
    caller: RpcCaller<JsonRpcSerializer>,
}

const CONTROL_INSTRUCTIONS_COMMON: &str =
    "Connected to an existing tunnel process running on this machine.";

const CONTROL_INSTRUCTIONS_INTERACTIVE: &str = concatcp!(
    CONTROL_INSTRUCTIONS_COMMON,
    " You can press:

- \"x\" + Enter to stop the tunnel and exit
- \"r\" + Enter to restart the tunnel
- Ctrl+C to detach
"
);

/// Serves a client singleton. Returns true if the process should exit after
/// this returns, instead of trying to start a tunnel.
pub async fn start_singleton_client(args: SingletonClientArgs) -> bool {
    let mut rpc = new_json_rpc();
    let (msg_tx, msg_rx) = mpsc::unbounded_channel();
    let exit_entirely = Arc::new(AtomicBool::new(false));

    debug!(
        args.log,
        "An existing tunnel is running on this machine, connecting to it..."
    );

    if *IS_INTERACTIVE_CLI {
        let stdin_handle = rpc.get_caller(msg_tx.clone());
        thread::spawn(move || {
            let mut input = String::new();
            loop {
                input.truncate(0);
                match std::io::stdin().read_line(&mut input) {
                    Err(_) | Ok(0) => return, // EOF or not a tty
                    _ => {}
                };

                match input.chars().next().map(|c| c.to_ascii_lowercase()) {
                    Some('x') => {
                        stdin_handle.notify(protocol::singleton::METHOD_SHUTDOWN, EmptyObject {});
                        return;
                    }
                    Some('r') => {
                        stdin_handle.notify(protocol::singleton::METHOD_RESTART, EmptyObject {});
                    }
                    Some(_) | None => {}
                }
            }
        });
    }

    let caller = rpc.get_caller(msg_tx);
    let mut rpc = rpc.methods(SingletonServerContext {
        log: args.log.clone(),
        exit_entirely: exit_entirely.clone(),
        caller,
    });

    rpc.register_sync(protocol::singleton::METHOD_SHUTDOWN, |_: EmptyObject, c| {
        c.exit_entirely.store(true, Ordering::SeqCst);
        Ok(())
    });

    rpc.register_async(
        protocol::singleton::METHOD_LOG_REPLY_DONE,
        |_: EmptyObject, c| async move {
            c.log.result(if *IS_INTERACTIVE_CLI {
                CONTROL_INSTRUCTIONS_INTERACTIVE
            } else {
                CONTROL_INSTRUCTIONS_COMMON
            });

            let res = c.caller.call::<_, _, protocol::singleton::Status>(
                protocol::singleton::METHOD_STATUS,
                protocol::EmptyObject {},
            );

            // We want to ensure the "listening" string always gets printed for
            // consumers (i.e. VS Code), so ask for it. If the tunnel is not
            // currently connected, it will be soon, and the name will appear
            // in the log replays instead.
            if let Ok(Ok(s)) = res.await {
                if let protocol::singleton::TunnelState::Connected { name } = s.tunnel {
                    print_listening(&c.log, &name);
                }
            }

            Ok(())
        },
    );

    rpc.register_sync(
        protocol::singleton::METHOD_LOG,
        |log: protocol::singleton::LogMessageOwned, c| {
            match log.level {
                Some(level) => c.log.emit(level, &format!("{}{}", log.prefix, log.message)),
                None => c.log.result(format!("{}{}", log.prefix, log.message)),
            }
            Ok(())
        },
    );

    let (read, write) = socket_stream_split(args.stream);
    let _ = start_json_rpc(rpc.build(args.log), read, write, msg_rx, args.shutdown).await;

    exit_entirely.load(Ordering::SeqCst)
}

pub async fn do_single_rpc_call<
    P: serde::Serialize + 'static,
    R: serde::de::DeserializeOwned + Send + 'static,
>(
    lock_file: &Path,
    log: log::Logger,
    method: &'static str,
    params: P,
) -> Result<R, CodeError> {
    let client = match connect_as_client(lock_file).await {
        Ok(p) => p,
        Err(CodeError::SingletonLockfileOpenFailed(_))
        | Err(CodeError::SingletonLockedProcessExited(_)) => {
            return Err(CodeError::NoRunningTunnel);
        }
        Err(e) => return Err(e),
    };

    let (msg_tx, msg_rx) = mpsc::unbounded_channel();
    let mut rpc = new_json_rpc();
    let caller = rpc.get_caller(msg_tx);
    let (read, write) = socket_stream_split(client);

    let rpc = tokio::spawn(async move {
        start_json_rpc(
            rpc.methods(()).build(log),
            read,
            write,
            msg_rx,
            ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
        )
        .await
        .unwrap();
    });

    let r = caller.call(method, params).await.unwrap();
    rpc.abort();
    r.map_err(CodeError::TunnelRpcCallFailed)
}
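
// Illustrative usage, mirroring the call site in service_windows.rs:
//
//     let r = do_single_rpc_call::<_, ()>(
//         &paths.tunnel_lockfile(),
//         log.clone(),
//         protocol::singleton::METHOD_SHUTDOWN,
//         protocol::EmptyObject {},
//     )
//     .await;
//
// A CodeError::NoRunningTunnel result means no singleton was listening on the
// lock file.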
cli/src/tunnels/singleton_server.rs (new file, 263 lines)
@@ -0,0 +1,263 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use std::{
    pin::Pin,
    sync::{Arc, Mutex},
};

use super::{
    code_server::CodeServerArgs,
    control_server::ServerTermination,
    dev_tunnels::ActiveTunnel,
    protocol,
    shutdown_signal::{ShutdownRequest, ShutdownSignal},
};
use crate::{
    async_pipe::socket_stream_split,
    json_rpc::{new_json_rpc, start_json_rpc, JsonRpcSerializer},
    log,
    rpc::{RpcCaller, RpcDispatcher},
    singleton::SingletonServer,
    state::LauncherPaths,
    tunnels::code_server::print_listening,
    update_service::Platform,
    util::{
        errors::{AnyError, CodeError},
        ring_buffer::RingBuffer,
        sync::{Barrier, ConcatReceivable},
    },
};
use futures::future::Either;
use tokio::{
    pin,
    sync::{broadcast, mpsc},
    task::JoinHandle,
};

pub struct SingletonServerArgs<'a> {
    pub server: &'a mut RpcServer,
    pub log: log::Logger,
    pub tunnel: ActiveTunnel,
    pub paths: &'a LauncherPaths,
    pub code_server_args: &'a CodeServerArgs,
    pub platform: Platform,
    pub shutdown: Barrier<ShutdownSignal>,
    pub log_broadcast: &'a BroadcastLogSink,
}

#[derive(Clone)]
struct SingletonServerContext {
    log: log::Logger,
    shutdown_tx: broadcast::Sender<ShutdownSignal>,
    broadcast_tx: broadcast::Sender<Vec<u8>>,
    current_name: Arc<Mutex<Option<String>>>,
}

pub struct RpcServer {
    fut: JoinHandle<Result<(), CodeError>>,
    shutdown_broadcast: broadcast::Sender<ShutdownSignal>,
    current_name: Arc<Mutex<Option<String>>>,
}

pub fn make_singleton_server(
    log_broadcast: BroadcastLogSink,
    log: log::Logger,
    server: SingletonServer,
    shutdown_rx: Barrier<ShutdownSignal>,
) -> RpcServer {
    let (shutdown_broadcast, _) = broadcast::channel(4);
    let rpc = new_json_rpc();

    let current_name = Arc::new(Mutex::new(None));
    let mut rpc = rpc.methods(SingletonServerContext {
        log: log.clone(),
        shutdown_tx: shutdown_broadcast.clone(),
        broadcast_tx: log_broadcast.get_broadcaster(),
        current_name: current_name.clone(),
    });

    rpc.register_sync(
        protocol::singleton::METHOD_RESTART,
        |_: protocol::EmptyObject, ctx| {
            info!(ctx.log, "restarting tunnel after client request");
            let _ = ctx.shutdown_tx.send(ShutdownSignal::RpcRestartRequested);
            Ok(())
        },
    );

    rpc.register_sync(
        protocol::singleton::METHOD_STATUS,
        |_: protocol::EmptyObject, c| {
            Ok(protocol::singleton::Status {
                tunnel: match c.current_name.lock().unwrap().clone() {
                    Some(name) => protocol::singleton::TunnelState::Connected { name },
                    None => protocol::singleton::TunnelState::Disconnected,
                },
            })
        },
    );

    rpc.register_sync(
        protocol::singleton::METHOD_SHUTDOWN,
        |_: protocol::EmptyObject, ctx| {
            info!(
                ctx.log,
                "closing tunnel and all clients after a shutdown request"
            );
            let _ = ctx.broadcast_tx.send(RpcCaller::serialize_notify(
                &JsonRpcSerializer {},
                protocol::singleton::METHOD_SHUTDOWN,
                protocol::EmptyObject {},
            ));
            let _ = ctx.shutdown_tx.send(ShutdownSignal::RpcShutdownRequested);
            Ok(())
        },
    );

    // We tokio::spawn instead of keeping a future, since we want it to make
    // progress even outside of the start_singleton_server loop (i.e. while
    // the tunnel restarts).
    let fut = tokio::spawn(async move {
        serve_singleton_rpc(log_broadcast, server, rpc.build(log), shutdown_rx).await
    });

    RpcServer {
        shutdown_broadcast,
        current_name,
        fut,
    }
}

pub async fn start_singleton_server<'a>(
    args: SingletonServerArgs<'_>,
) -> Result<ServerTermination, AnyError> {
    let shutdown_rx = ShutdownRequest::create_rx([
        ShutdownRequest::Derived(Box::new(args.server.shutdown_broadcast.subscribe())),
        ShutdownRequest::Derived(Box::new(args.shutdown.clone())),
    ]);

    {
        print_listening(&args.log, &args.tunnel.name);
        let mut name = args.server.current_name.lock().unwrap();
        *name = Some(args.tunnel.name.clone())
    }

    let serve_fut = super::serve(
        &args.log,
        args.tunnel,
        args.paths,
        args.code_server_args,
        args.platform,
        shutdown_rx,
    );

    pin!(serve_fut);

    match futures::future::select(Pin::new(&mut args.server.fut), &mut serve_fut).await {
        Either::Left((rpc_result, fut)) => {
            // The rpc server only ends as a result of a graceful shutdown or
            // an error. Return the result of the eventual shutdown of the
            // control server.
            rpc_result.unwrap()?;
            fut.await
        }
        Either::Right((ctrl_result, _)) => ctrl_result,
    }
}

async fn serve_singleton_rpc<C: Clone + Send + Sync + 'static>(
    log_broadcast: BroadcastLogSink,
    mut server: SingletonServer,
    dispatcher: RpcDispatcher<JsonRpcSerializer, C>,
    shutdown_rx: Barrier<ShutdownSignal>,
) -> Result<(), CodeError> {
    let mut own_shutdown = shutdown_rx.clone();
    let shutdown_fut = own_shutdown.wait();
    pin!(shutdown_fut);

    loop {
        let cnx = tokio::select! {
            c = server.accept() => c?,
            _ = &mut shutdown_fut => return Ok(()),
        };

        let (read, write) = socket_stream_split(cnx);
        let dispatcher = dispatcher.clone();
        let msg_rx = log_broadcast.replay_and_subscribe();
        let shutdown_rx = shutdown_rx.clone();
        tokio::spawn(async move {
            let _ = start_json_rpc(dispatcher.clone(), read, write, msg_rx, shutdown_rx).await;
        });
    }
}

/// Log sink that can broadcast and replay log events. Used for transmitting
/// logs from the singleton to all clients. This should be created and injected
/// into other services, like the tunnel, before `start_singleton_server`
/// is called.
#[derive(Clone)]
pub struct BroadcastLogSink {
    recent: Arc<Mutex<RingBuffer<Vec<u8>>>>,
    tx: broadcast::Sender<Vec<u8>>,
}

impl Default for BroadcastLogSink {
    fn default() -> Self {
        Self::new()
    }
}

impl BroadcastLogSink {
    pub fn new() -> Self {
        let (tx, _) = broadcast::channel(64);
        Self {
            tx,
            recent: Arc::new(Mutex::new(RingBuffer::new(50))),
        }
    }

    fn get_broadcaster(&self) -> broadcast::Sender<Vec<u8>> {
        self.tx.clone()
    }

    fn replay_and_subscribe(
        &self,
    ) -> ConcatReceivable<Vec<u8>, mpsc::UnboundedReceiver<Vec<u8>>, broadcast::Receiver<Vec<u8>>> {
        let (log_replay_tx, log_replay_rx) = mpsc::unbounded_channel();

        for log in self.recent.lock().unwrap().iter() {
            let _ = log_replay_tx.send(log.clone());
        }

        let _ = log_replay_tx.send(RpcCaller::serialize_notify(
            &JsonRpcSerializer {},
            protocol::singleton::METHOD_LOG_REPLY_DONE,
            protocol::EmptyObject {},
        ));

        ConcatReceivable::new(log_replay_rx, self.tx.subscribe())
    }
}

impl log::LogSink for BroadcastLogSink {
    fn write_log(&self, level: log::Level, prefix: &str, message: &str) {
        let s = JsonRpcSerializer {};
        let serialized = RpcCaller::serialize_notify(
            &s,
            protocol::singleton::METHOD_LOG,
            protocol::singleton::LogMessage {
                level: Some(level),
                prefix,
                message,
            },
        );

        let _ = self.tx.send(serialized.clone());
        self.recent.lock().unwrap().push(serialized);
    }

    fn write_result(&self, message: &str) {
        self.write_log(log::Level::Info, "", message);
    }
}
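
// For orientation: a client attached via replay_and_subscribe() observes, in
// order, (1) up to 50 buffered "log" notifications replayed from the ring
// buffer, (2) a single METHOD_LOG_REPLY_DONE notification, which the
// singleton client uses as its cue to print control instructions and query
// status, and (3) live log notifications from the broadcast channel.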
cli/src/tunnels/socket_signal.rs (new file, 290 lines)
@@ -0,0 +1,290 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 * Licensed under the Source EULA. See License.txt in the project root for license information.
 *--------------------------------------------------------------------------------------------*/

use serde::Serialize;
use tokio::sync::mpsc;

use crate::msgpack_rpc::MsgPackCaller;

use super::{
    protocol::{ClientRequestMethod, RefServerMessageParams, ToClientRequest},
    server_multiplexer::ServerMultiplexer,
};

pub struct CloseReason(pub String);

pub enum SocketSignal {
    /// Signals bytes to send to the socket.
    Send(Vec<u8>),
    /// Closes the socket (e.g. as a result of an error)
    CloseWith(CloseReason),
}

impl From<Vec<u8>> for SocketSignal {
    fn from(v: Vec<u8>) -> Self {
        SocketSignal::Send(v)
    }
}

impl SocketSignal {
    pub fn from_message<T>(msg: &T) -> Self
    where
        T: Serialize + ?Sized,
    {
        SocketSignal::Send(rmp_serde::to_vec_named(msg).unwrap())
    }
}

/// todo@connor4312: clean up once everything is moved to rpc standard interfaces
#[allow(dead_code)]
pub enum ServerMessageDestination {
    Channel(mpsc::Sender<SocketSignal>),
    Rpc(MsgPackCaller),
}

/// Struct that handles sending to, or closing, a connected server socket.
pub struct ServerMessageSink {
    id: u16,
    tx: Option<ServerMessageDestination>,
    multiplexer: ServerMultiplexer,
    flate: Option<FlateStream<CompressFlateAlgorithm>>,
}

impl ServerMessageSink {
    pub fn new_plain(
        multiplexer: ServerMultiplexer,
        id: u16,
        tx: ServerMessageDestination,
    ) -> Self {
        Self {
            tx: Some(tx),
            id,
            multiplexer,
            flate: None,
        }
    }

    pub fn new_compressed(
        multiplexer: ServerMultiplexer,
        id: u16,
        tx: ServerMessageDestination,
    ) -> Self {
        Self {
            tx: Some(tx),
            id,
            multiplexer,
            flate: Some(FlateStream::new(CompressFlateAlgorithm(
                flate2::Compress::new(flate2::Compression::new(2), false),
            ))),
        }
    }

    pub async fn server_message(
        &mut self,
        body: &[u8],
    ) -> Result<(), mpsc::error::SendError<SocketSignal>> {
        let id = self.id;
        let mut tx = self.tx.take().unwrap();
        let body = self.get_server_msg_content(body);
        let msg = RefServerMessageParams { i: id, body };

        let r = match &mut tx {
            ServerMessageDestination::Channel(tx) => {
                tx.send(SocketSignal::from_message(&ToClientRequest {
                    id: None,
                    params: ClientRequestMethod::servermsg(msg),
                }))
                .await
            }
            ServerMessageDestination::Rpc(caller) => {
                caller.notify("servermsg", msg);
                Ok(())
            }
        };

        self.tx = Some(tx);
        r
    }

    pub(crate) fn get_server_msg_content<'a: 'b, 'b>(&'a mut self, body: &'b [u8]) -> &'b [u8] {
        if let Some(flate) = &mut self.flate {
            if let Ok(compressed) = flate.process(body) {
                return compressed;
            }
        }

        body
    }
}
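
// Design note: compression uses a single persistent flate stream per sink,
// flushed with FlushCompress::Sync after each message (see
// CompressFlateAlgorithm below). Each message therefore ends on a byte
// boundary the decoder can consume immediately, while both sides keep a
// shared dictionary across messages for better compression ratios.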

impl Drop for ServerMessageSink {
    fn drop(&mut self) {
        self.multiplexer.remove(self.id);
    }
}

pub struct ClientMessageDecoder {
    dec: Option<FlateStream<DecompressFlateAlgorithm>>,
}

impl ClientMessageDecoder {
    pub fn new_plain() -> Self {
        ClientMessageDecoder { dec: None }
    }

    pub fn new_compressed() -> Self {
        ClientMessageDecoder {
            dec: Some(FlateStream::new(DecompressFlateAlgorithm(
                flate2::Decompress::new(false),
            ))),
        }
    }

    pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> {
        match &mut self.dec {
            Some(d) => d.process(message),
            None => Ok(message),
        }
    }
}

trait FlateAlgorithm {
    fn total_in(&self) -> u64;
    fn total_out(&self) -> u64;
    fn process(
        &mut self,
        contents: &[u8],
        output: &mut [u8],
    ) -> Result<flate2::Status, std::io::Error>;
}

struct DecompressFlateAlgorithm(flate2::Decompress);

impl FlateAlgorithm for DecompressFlateAlgorithm {
    fn total_in(&self) -> u64 {
        self.0.total_in()
    }

    fn total_out(&self) -> u64 {
        self.0.total_out()
    }

    fn process(
        &mut self,
        contents: &[u8],
        output: &mut [u8],
    ) -> Result<flate2::Status, std::io::Error> {
        self.0
            .decompress(contents, output, flate2::FlushDecompress::None)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
    }
}

struct CompressFlateAlgorithm(flate2::Compress);

impl FlateAlgorithm for CompressFlateAlgorithm {
    fn total_in(&self) -> u64 {
        self.0.total_in()
    }

    fn total_out(&self) -> u64 {
        self.0.total_out()
    }

    fn process(
        &mut self,
        contents: &[u8],
        output: &mut [u8],
    ) -> Result<flate2::Status, std::io::Error> {
        self.0
            .compress(contents, output, flate2::FlushCompress::Sync)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
    }
}

struct FlateStream<A>
where
    A: FlateAlgorithm,
{
    flate: A,
    output: Vec<u8>,
}

impl<A> FlateStream<A>
where
    A: FlateAlgorithm,
{
    pub fn new(alg: A) -> Self {
        Self {
            flate: alg,
            output: vec![0; 4096],
        }
    }

    pub fn process(&mut self, contents: &[u8]) -> std::io::Result<&[u8]> {
        let mut out_offset = 0;
        let mut in_offset = 0;
        loop {
            let in_before = self.flate.total_in();
            let out_before = self.flate.total_out();

            match self
                .flate
                .process(&contents[in_offset..], &mut self.output[out_offset..])
            {
                Ok(flate2::Status::Ok | flate2::Status::BufError) => {
                    let processed_len = in_offset + (self.flate.total_in() - in_before) as usize;
                    let output_len = out_offset + (self.flate.total_out() - out_before) as usize;
                    if processed_len < contents.len() {
                        // If we filled the output buffer but there's more data
                        // to process, extend the output buffer and keep going.
                        out_offset = output_len;
                        in_offset = processed_len;
                        if output_len == self.output.len() {
                            self.output.resize(self.output.len() * 2, 0);
                        }
                        continue;
                    }

                    return Ok(&self.output[..output_len]);
                }
                Ok(flate2::Status::StreamEnd) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "unexpected stream end",
                    ))
                }
                Err(e) => return Err(e),
            }
        }
    }
}

#[cfg(test)]
mod tests {
    // Note this useful idiom: importing names from outer (for mod tests) scope.
    use super::*;

    #[test]
    fn test_round_trips_compression() {
        let (tx, _) = mpsc::channel(1);
        let mut sink = ServerMessageSink::new_compressed(
            ServerMultiplexer::new(),
            0,
            ServerMessageDestination::Channel(tx),
        );
        let mut decompress = ClientMessageDecoder::new_compressed();

        // 3000 and 30000 test resizing the buffer
        for msg_len in [3, 30, 300, 3000, 30000] {
            let vals = (0..msg_len).map(|v| v as u8).collect::<Vec<u8>>();
            let compressed = sink.get_server_msg_content(&vals);
            assert_ne!(compressed, vals);
            let decompressed = decompress.decode(compressed).unwrap();
            assert_eq!(decompressed.len(), vals.len());
            assert_eq!(decompressed, vals);
        }
    }
}